| author | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2022-10-26 18:00:11 +0200 |
|---|---|---|
| committer | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2022-10-26 18:00:11 +0200 |
| commit | e85b70b4cd38d27c810385e1a79fdf98016f5f88 (patch) | |
| tree | edbceac2ba8b25994772af4f3968b6c0a8a0ee82 | |
| parent | 99838c7e719057aa42b68acad2290eb57bebe78e (diff) | |
[server] Add CertificateRetrievalMode and start() method to start listening
-rw-r--r-- | src/server.rs | 434 |
1 file changed, 251 insertions, 183 deletions
```diff
diff --git a/src/server.rs b/src/server.rs
index aeb91d1..c6de319 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,42 +1,38 @@
-use std::{
-    collections::{HashMap},
-    error::Error,
-    net::SocketAddr,
-    sync::{Arc},
-    time::Duration,
-};
+use std::{collections::HashMap, error::Error, net::SocketAddr, sync::Arc, time::Duration};
 
 use bevy::prelude::*;
 use bytes::Bytes;
-use futures::{
-    sink::SinkExt,
-};
+use futures::sink::SinkExt;
 use futures_util::StreamExt;
 use quinn::{Endpoint, NewConnection, ServerConfig};
+use rustls::Certificate;
 use serde::Deserialize;
 use tokio::{
     runtime::Runtime,
     sync::{
         broadcast::{self},
-        mpsc::{
-            self,
-            error::{TryRecvError},
-        },
+        mpsc::{self, error::TryRecvError},
     },
-    task::{ JoinSet},
+    task::JoinSet,
 };
 use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
 
-use crate::{ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_MESSAGE_QUEUE_SIZE, DEFAULT_KILL_MESSAGE_QUEUE_SIZE};
+use crate::{
+    ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_KILL_MESSAGE_QUEUE_SIZE,
+    DEFAULT_MESSAGE_QUEUE_SIZE,
+};
 
 pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100;
-
 /// Connection event raised when a client just connected to the server. Raised in the CoreStage::PreUpdate stage.
-pub struct ConnectionEvent { pub id: ClientId }
+pub struct ConnectionEvent {
+    pub id: ClientId,
+}
 /// ConnectionLost event raised when a client is considered disconnected from the server. Raised in the CoreStage::PreUpdate stage.
-pub struct ConnectionLostEvent { pub id: ClientId }
+pub struct ConnectionLostEvent {
+    pub id: ClientId,
+}
 
 #[derive(Debug)]
 pub(crate) enum InternalAsyncMessage {
@@ -46,10 +42,14 @@ pub(crate) enum InternalAsyncMessage {
 
 #[derive(Debug, Clone)]
 pub(crate) enum InternalSyncMessage {
+    StartListening {
+        config: ServerConfigurationData,
+        cert_mode: CertificateRetrievalMode,
+    },
     ClientConnectedAck(ClientId),
 }
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 pub struct ServerConfigurationData {
     host: String,
     port: u16,
@@ -57,11 +57,7 @@ pub struct ServerConfigurationData {
 }
 
 impl ServerConfigurationData {
-    pub fn new(
-        host: String,
-        port: u16,
-        local_bind_host: String,
-    ) -> Self {
+    pub fn new(host: String, port: u16, local_bind_host: String) -> Self {
         Self {
             host,
             port,
@@ -70,7 +66,6 @@ impl ServerConfigurationData {
     }
 }
 
-
 #[derive(Debug)]
 pub struct ClientPayload {
     client_id: ClientId,
@@ -79,7 +74,7 @@ pub struct ClientPayload {
 
 #[derive(Debug)]
 pub(crate) struct ClientConnection {
-    client_id: ClientId, 
+    client_id: ClientId,
     sender: mpsc::Sender<Bytes>,
     close_sender: broadcast::Sender<()>,
 }
@@ -92,15 +87,45 @@ pub struct Server {
     pub(crate) internal_sender: broadcast::Sender<InternalSyncMessage>,
 }
 
+/// How the server should retrieve its certificate.
+#[derive(Debug, Clone)]
+pub enum CertificateRetrievalMode {
+    Provided(Certificate),
+    GenerateSelfSigned,
+    LoadCertFromFile(String),
+    LoadCertFromFileOrGenerateSelfSigned(String),
+}
+
 impl Server {
+    /// Run the server with the given [ServerConfigurationData] and [CertificateRetrievalMode]
+    pub fn start(
+        &self,
+        config: ServerConfigurationData,
+        cert_mode: CertificateRetrievalMode,
+    ) -> Result<(), QuinnetError> {
+        match self
+            .internal_sender
+            .send(InternalSyncMessage::StartListening { config, cert_mode })
+        {
+            Ok(_) => Ok(()),
+            Err(_) => Err(QuinnetError::FullQueue),
+        }
+    }
+
     pub fn disconnect_client(&mut self, client_id: ClientId) {
         match self.clients.remove(&client_id) {
             Some(client_connection) => {
                 if let Err(_) = client_connection.close_sender.send(()) {
-                    error!("Failed to close client streams & resources while disconnecting client {}", client_id)
+                    error!(
+                        "Failed to close client streams & resources while disconnecting client {}",
+                        client_id
+                    )
                 }
-            },
-            None => error!("Failed to disconnect client {}, client not found", client_id),
+            }
+            None => error!(
+                "Failed to disconnect client {}, client not found",
+                client_id
+            ),
         }
     }
 
@@ -127,7 +152,7 @@ impl Server {
         }
     }
 
-    pub fn send_group_message<'a, I: Iterator<Item=&'a ClientId>, T: serde::Serialize>(
+    pub fn send_group_message<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>(
         &mut self,
         client_ids: I,
         message: T,
@@ -139,7 +164,7 @@ impl Server {
                     self.send_payload(*id, payload.clone());
                 }
                 Ok(())
-            },
+            }
             Err(_) => Err(QuinnetError::Serialization),
         }
     }
@@ -157,7 +182,10 @@ impl Server {
 
     pub fn broadcast_payload<T: Into<Bytes> + Clone>(&mut self, payload: T) {
         // TODO Fix: Error handling
         for (_, client_connection) in self.clients.iter() {
-            client_connection.sender.try_send(payload.clone().into()).unwrap();
+            client_connection
+                .sender
+                .try_send(payload.clone().into())
+                .unwrap();
         }
     }
@@ -183,25 +211,36 @@ impl Server {
 }
 
 /// Returns default server configuration along with its certificate.
-#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527
-fn configure_server(server_host: &String) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
-    let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap();
-    let cert_der = cert.serialize_der().unwrap();
-    let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
-    let cert_chain = vec![rustls::Certificate(cert_der.clone())];
-
-    let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
-    Arc::get_mut(&mut server_config.transport)
-        .unwrap()
-        .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
-
-    Ok((server_config, cert_der))
+fn configure_server(
+    server_host: &String,
+    cert_mode: CertificateRetrievalMode,
+) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
+    match cert_mode {
+        CertificateRetrievalMode::Provided(_cert) => todo!(),
+        CertificateRetrievalMode::GenerateSelfSigned => {
+            let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap();
+            let cert_der = cert.serialize_der().unwrap();
+            let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
+            let cert_chain = vec![rustls::Certificate(cert_der.clone())];
+
+            let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
+            Arc::get_mut(&mut server_config.transport)
+                .unwrap()
+                .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
+
+            Ok((server_config, cert_der))
+        }
+        CertificateRetrievalMode::LoadCertFromFile(_) => todo!(),
+        CertificateRetrievalMode::LoadCertFromFileOrGenerateSelfSigned(_) => todo!(),
    }
 }
 
-fn start_server(
-    mut commands: Commands,
-    runtime: Res<Runtime>,
-    config: Res<ServerConfigurationData>,
+async fn connections_listening_task(
+    config: ServerConfigurationData,
+    cert_mode: CertificateRetrievalMode,
+    to_sync_server: mpsc::Sender<InternalAsyncMessage>,
+    mut from_sync_server: broadcast::Receiver<InternalSyncMessage>,
+    from_clients_sender: mpsc::Sender<ClientPayload>,
 ) {
     // TODO Fix: handle unwraps
     let server_adr_str = format!("{}:{}", config.local_bind_host, config.port);
@@ -211,153 +250,179 @@ fn start_server(
         .parse()
         .expect("Failed to parse server address");
 
-    // TODO Security: Server certificate
     let (server_config, _server_cert) =
-        configure_server(&config.host).expect("Failed to configure server");
+        configure_server(&config.host, cert_mode).expect("Failed to configure server");
+
+    let mut client_gen_id: ClientId = 0;
+    let mut client_id_mappings = HashMap::new();
+
+    let (_endpoint, mut incoming) =
+        Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint");
+
+    // Start iterating over incoming connections/clients.
+    while let Some(conn) = incoming.next().await {
+        let mut new_connection: NewConnection =
+            conn.await.expect("Failed to handle incoming connection");
+        let connection = new_connection.connection;
+
+        // Attribute an id to this client
+        client_gen_id += 1; // TODO Fix: Better id generation/check
+        let client_id = client_gen_id;
+        client_id_mappings.insert(connection.stable_id(), client_id);
+
+        info!(
+            "New connection from {}, client_id: {}, stable_id : {}",
+            connection.remote_address(),
+            client_id,
+            connection.stable_id()
+        );
+
+        // Create a close channel for this client
+        let (close_sender, mut close_receiver): (
+            tokio::sync::broadcast::Sender<()>,
+            tokio::sync::broadcast::Receiver<()>,
+        ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
+
+        // Create an ordered reliable send channel for this client
+        let (to_client_sender, mut to_client_receiver) =
+            mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
+        let send_stream = connection.open_uni().await.expect(
+            format!(
+                "Failed to open unidirectional send stream for client: {}",
+                client_id
+            )
+            .as_str(),
+        );
+
+        let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
+
+        let to_sync_server_clone = to_sync_server.clone();
+        let close_sender_clone = close_sender.clone();
+        let _network_broadcaster = tokio::spawn(async move {
+            tokio::select! {
+                _ = close_receiver.recv() => {
+                    trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
+                }
+                _ = async {
+                    while let Some(msg_bytes) = to_client_receiver.recv().await {
+                        // TODO Perf: Batch frames for a send_all
+                        // TODO Clean: Error handling
+                        if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
+                            error!("Error while sending to client {}: {}", client_id, err);
+                            error!("Client {} seems disconnected, closing resources", client_id);
+                            if let Err(_) = close_sender_clone.send(()) {
+                                error!("Failed to close all client streams & resources for client {}", client_id)
+                            }
+                            to_sync_server_clone.send(
+                                InternalAsyncMessage::ClientLostConnection(client_id))
+                                .await
+                                .expect("Failed to signal connection lost to sync server");
+                        };
+                    }
+                } => {}
+            }
+        });
+
+        // Signal the sync server of this new connection
+        // let mut sync_msg_receiver = to_async_server.subscribe();
+        to_sync_server
+            .send(InternalAsyncMessage::ClientConnected(ClientConnection {
+                client_id: client_id,
+                sender: to_client_sender,
+                close_sender: close_sender.clone(),
+            }))
+            .await
+            .expect("Failed to signal connection to sync client");
+
+        // Wait for the sync server to acknowledge the connection before spawning reception tasks.
+        while let Ok(message) = from_sync_server.recv().await {
+            match message {
+                InternalSyncMessage::ClientConnectedAck(id) => {
+                    if id == client_id {
+                        break;
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        // Spawn a task to listen for stream opened from this client
+        let from_client_sender_clone = from_clients_sender.clone();
+        let mut uni_receivers: JoinSet<()> = JoinSet::new();
+        let mut close_receiver = close_sender.subscribe();
+        let _client_receiver = tokio::spawn(async move {
+            tokio::select! {
+                _ = close_receiver.recv() => {
+                    trace!("New Stream listener forced to disconnected for client: {}", client_id)
+                }
+                _ = async {
+                    // For each new stream opened by the client
+                    while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
+                        let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
+
+                        // Spawn a task to receive data on this stream.
+                        let from_client_sender = from_client_sender_clone.clone();
+                        uni_receivers.spawn(async move {
+                            while let Some(Ok(msg_bytes)) = frame_recv.next().await {
+                                from_client_sender
+                                    .send(ClientPayload {
+                                        client_id: client_id,
+                                        msg: msg_bytes.into(),
+                                    })
+                                    .await
+                                    .unwrap();// TODO Fix: error event
+                            }
+                            trace!("Unidirectional stream receiver ended for client: {}", client_id)
+                        });
+                    }
+                } => {
+                    trace!("New Stream listener ended for client: {}", client_id)
+                }
+            }
+            uni_receivers.shutdown().await;
+            trace!(
+                "All unidirectional stream receivers cleaned for client: {}",
+                client_id
+            )
+        });
+    }
+}
+
+fn start_async_server(mut commands: Commands, runtime: Res<Runtime>) {
     // TODO Clean: Configure size
     let (from_clients_sender, from_clients_receiver) =
         mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
     let (to_sync_server, from_async_server) =
         mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
-    let (to_async_server, _) =
+    let (to_async_server, mut from_sync_server) =
        broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
 
     commands.insert_resource(Server {
         clients: HashMap::new(),
         receiver: from_clients_receiver,
         internal_receiver: from_async_server,
-        internal_sender: to_async_server.clone()
+        internal_sender: to_async_server.clone(),
     });
-
-    // Create server task
+    // Create async server task
     runtime.spawn(async move {
-        let mut client_gen_id: ClientId = 0;
-        let mut client_id_mappings = HashMap::new();
-
-        let (_endpoint, mut incoming) =
-            Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint");
-
-        // Start iterating over incoming connections/clients.
-        while let Some(conn) = incoming.next().await {
-            let mut new_connection: NewConnection =
-                conn.await.expect("Failed to handle incoming connection");
-            let connection = new_connection.connection;
-
-            // Attribute an id to this client
-            client_gen_id += 1; // TODO Fix: Better id generation/check
-            let client_id = client_gen_id;
-            client_id_mappings.insert(connection.stable_id(), client_id);
-
-            info!(
-                "New connection from {}, client_id: {}, stable_id : {}",
-                connection.remote_address(),
-                client_id,
-                connection.stable_id()
-            );
-
-            // Create a close channel for this client
-            let (close_sender, mut close_receiver): (
-                tokio::sync::broadcast::Sender<()>,
-                tokio::sync::broadcast::Receiver<()>,
-            ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
-
-            // Create an ordered reliable send channel for this client
-            let (to_client_sender, mut to_client_receiver) =
-                mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
-
-            let send_stream = connection
-                .open_uni()
-                .await
-                .expect( format!("Failed to open unidirectional send stream for client: {}", client_id).as_str());
-
-            let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
-
-            let to_sync_server_clone = to_sync_server.clone();
-            let close_sender_clone = close_sender.clone();
-            let _network_broadcaster = tokio::spawn(async move {
-                tokio::select! {
-                    _ = close_receiver.recv() => {
-                        trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
-                    }
-                    _ = async {
-                        while let Some(msg_bytes) = to_client_receiver.recv().await {
-                            // TODO Perf: Batch frames for a send_all
-                            // TODO Clean: Error handling
-                            if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
-                                error!("Error while sending to client {}: {}", client_id, err);
-                                error!("Client {} seems disconnected, closing resources", client_id);
-                                if let Err(_) = close_sender_clone.send(()) {
-                                    error!("Failed to close all client streams & resources for client {}", client_id)
-                                }
-                                to_sync_server_clone.send(
-                                    InternalAsyncMessage::ClientLostConnection(client_id))
-                                    .await
-                                    .expect("Failed to signal connection lost to sync server");
-                            };
-                        }
-                    } => {}
-                }
-            });
-
-            // Signal the sync server of this new connection
-            let mut sync_msg_receiver = to_async_server.subscribe();
-            to_sync_server.send(
-                InternalAsyncMessage::ClientConnected(ClientConnection {
-                    client_id: client_id,
-                    sender: to_client_sender,
-                    close_sender: close_sender.clone()
-                }))
-                .await
-                .expect("Failed to signal connection to sync client");
-
-            // Wait for the sync server to acknowledge the connection before spawning reception tasks.
-            while let Ok(message) = sync_msg_receiver.recv().await {
-                match message {
-                    InternalSyncMessage::ClientConnectedAck(id) =>{
-                        if id == client_id {break;}
-                    },
+        // Wait for the sync server to to signal us to start listening
+        while let Ok(message) = from_sync_server.recv().await {
+            match message {
+                InternalSyncMessage::StartListening { config, cert_mode } => {
+                    connections_listening_task(
+                        config,
+                        cert_mode,
+                        to_sync_server.clone(),
+                        to_async_server.subscribe(),
+                        from_clients_sender.clone(),
+                    )
+                    .await;
                 }
+                _ => {}
             }
-
-            // Spawn a task to listen for stream opened from this client
-            let from_client_sender_clone = from_clients_sender.clone();
-            let mut uni_receivers:JoinSet<()> = JoinSet::new();
-            let mut close_receiver = close_sender.subscribe();
-            let _client_receiver = tokio::spawn(async move {
-                tokio::select! {
-                    _ = close_receiver.recv() => {
-                        trace!("New Stream listener forced to disconnected for client: {}", client_id)
-                    }
-                    _ = async {
-                        // For each new stream opened by the client
-                        while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
-                            let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
-
-                            // Spawn a task to receive data on this stream.
-                            let from_client_sender = from_client_sender_clone.clone();
-                            uni_receivers.spawn(async move {
-                                while let Some(Ok(msg_bytes)) = frame_recv.next().await {
-                                    from_client_sender
-                                        .send(ClientPayload {
-                                            client_id: client_id,
-                                            msg: msg_bytes.into(),
-                                        })
-                                        .await
-                                        .unwrap();// TODO Fix: error event
-                                }
-                                trace!("Unidirectional stream receiver ended for client: {}", client_id)
-                            });
-                        }
-                    } => {
-                        trace!("New Stream listener ended for client: {}", client_id)
-                    }
-                }
-                uni_receivers.shutdown().await;
-                trace!("All unidirectional stream receivers cleaned for client: {}", client_id)
-            });
-        }
     });
 }
@@ -366,20 +431,23 @@ fn start_server(
 fn update_sync_server(
     mut server: ResMut<Server>,
     mut connection_events: EventWriter<ConnectionEvent>,
-    mut connection_lost_events: EventWriter<ConnectionLostEvent>
+    mut connection_lost_events: EventWriter<ConnectionLostEvent>,
 ) {
     while let Ok(message) = server.internal_receiver.try_recv() {
         match message {
             InternalAsyncMessage::ClientConnected(connection) => {
                 let id = connection.client_id;
                 server.clients.insert(id, connection);
-                server.internal_sender.send(InternalSyncMessage::ClientConnectedAck(id)).unwrap();
+                server
+                    .internal_sender
+                    .send(InternalSyncMessage::ClientConnectedAck(id))
+                    .unwrap();
                 connection_events.send(ConnectionEvent { id: id });
             }
             InternalAsyncMessage::ClientLostConnection(client_id) => {
                 server.clients.remove(&client_id);
                 connection_lost_events.send(ConnectionLostEvent { id: client_id });
-            },
+            }
         }
     }
 }
@@ -402,7 +470,7 @@ impl Plugin for QuinnetServerPlugin {
             )
             .add_event::<ConnectionEvent>()
             .add_event::<ConnectionLostEvent>()
-            .add_startup_system_to_stage(StartupStage::PreStartup, start_server)
-            .add_system_to_stage(CoreStage::PreUpdate,update_sync_server);
+            .add_startup_system_to_stage(StartupStage::PreStartup, start_async_server)
+            .add_system_to_stage(CoreStage::PreUpdate, update_sync_server);
     }
 }
```
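For reference, here is a minimal usage sketch of the API this commit adds: the plugin now only spawns the async backend at startup, and the application explicitly calls `Server::start()` with a `ServerConfigurationData` and a `CertificateRetrievalMode` to begin listening. Only `Server::start()`, `ServerConfigurationData::new()`, `CertificateRetrievalMode::GenerateSelfSigned`, `ConnectionEvent` and `QuinnetServerPlugin` come from the diff above; the `bevy_quinnet::server` module path, the `QuinnetServerPlugin::default()` constructor and the host/port values are assumptions made for illustration.

```rust
// Hypothetical usage sketch (not part of the commit). Assumes the crate
// re-exports these types as `bevy_quinnet::server::*`, that the plugin
// implements `Default`, and that it sets up the tokio Runtime it needs.
use bevy::prelude::*;
use bevy_quinnet::server::{
    CertificateRetrievalMode, ConnectionEvent, QuinnetServerPlugin, Server,
    ServerConfigurationData,
};

fn main() {
    App::new()
        .add_plugins(MinimalPlugins)
        .add_plugin(QuinnetServerPlugin::default())
        // The plugin inserts the `Server` resource during `StartupStage::PreStartup`,
        // so a regular startup system can ask it to start listening.
        .add_startup_system(start_listening)
        .add_system(handle_connection_events)
        .run();
}

fn start_listening(server: Res<Server>) {
    server
        .start(
            // Example values: certificate host, port, and local bind address.
            ServerConfigurationData::new("127.0.0.1".to_string(), 6000, "0.0.0.0".to_string()),
            CertificateRetrievalMode::GenerateSelfSigned,
        )
        .expect("Failed to start the server");
}

fn handle_connection_events(mut connection_events: EventReader<ConnectionEvent>) {
    // ConnectionEvent is raised in CoreStage::PreUpdate when a client connects.
    for client in connection_events.iter() {
        info!("Client {} connected", client.id);
    }
}
```

Note that in this commit only `GenerateSelfSigned` is actually handled by `configure_server()`; the `Provided`, `LoadCertFromFile` and `LoadCertFromFileOrGenerateSelfSigned` variants still resolve to `todo!()`.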