diff options
Diffstat (limited to 'src/server.rs')
-rw-r--r-- | src/server.rs | 276 |
1 files changed, 276 insertions, 0 deletions
diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..9fa010d --- /dev/null +++ b/src/server.rs @@ -0,0 +1,276 @@ +use std::{ + collections::{HashMap}, + error::Error, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; + +use bevy::prelude::*; +use bytes::Bytes; +use futures::{ + sink::SinkExt, +}; +use futures_util::StreamExt; +use quinn::{Endpoint, NewConnection, ServerConfig}; +use serde::Deserialize; +use tokio::{ + runtime::Runtime, + sync::{ + broadcast::{self, error::SendError}, + mpsc::{ + self, + error::{TryRecvError}, + }, + }, + task::{ JoinSet}, +}; +use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; + +use crate::{ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_MESSAGE_QUEUE_SIZE, DEFAULT_KILL_MESSAGE_QUEUE_SIZE}; + +#[derive(Deserialize)] +pub struct ServerConfigurationData { + host: String, + port: u16, + local_bind_host: String, +} + +#[derive(Debug)] +pub struct ClientPayload { + client_id: ClientId, + msg: Bytes, +} + +pub struct Server { + broadcast_sender: broadcast::Sender<Bytes>, + receiver: mpsc::Receiver<ClientPayload>, + close_senders: Arc<Mutex<HashMap<ClientId, broadcast::Sender<()>>>>, +} + +impl Server { + pub fn disconnect_client(&mut self, client_id: ClientId) { + match self.close_senders.lock() { + Ok(senders) => { + match senders.get(&client_id) { + Some(close_sender) => + if let Err(_) = close_sender.send(()) { + error!("Failed to close client streams & resources while disconnecting client {}", client_id) + } , + None => warn!("Tried to disconnect unknown client {}", client_id), + } + }, + Err(_) => error!("Failed to acquire lock while disconnecting client {}", client_id), + } + } + + pub fn receive_message<T: serde::de::DeserializeOwned>( + &mut self, + ) -> Result<Option<(T, ClientId)>, QuinnetError> { + match self.receive_payload()? { + Some(client_msg) => match bincode::deserialize(&client_msg.msg) { + Ok(msg) => Ok(Some((msg, client_msg.client_id))), + Err(_) => Err(QuinnetError::Deserialization), + }, + None => Ok(None), + } + } + + pub fn broadcast_message<T: serde::Serialize>( + &mut self, + message: T, + ) -> Result<(), QuinnetError> { + match bincode::serialize(&message) { + Ok(payload) => self.broadcast_payload(payload), + Err(_) => Err(QuinnetError::Serialization), + } + } + + pub fn broadcast_payload<T: Into<Bytes>>(&mut self, payload: T) -> Result<(), QuinnetError> { + match self.broadcast_sender.send(payload.into()) { + Ok(_) => Ok(()), + Err(err) => match err { + SendError(_) => Err(QuinnetError::ChannelClosed), + } + } + } + + //TODO Clean: Consider receiving payloads for a specified client + pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> { + match self.receiver.try_recv() { + Ok(msg) => Ok(Some(msg)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed), + }, + } + } +} + +pub struct QuinnetServerPlugin {} + +impl Default for QuinnetServerPlugin { + fn default() -> Self { + Self {} + } +} + +impl Plugin for QuinnetServerPlugin { + fn build(&self, app: &mut App) { + app.insert_resource( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ) + .add_startup_system_to_stage(StartupStage::PreStartup, start_server); + } +} + +/// Returns default server configuration along with its certificate. +#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527 +fn configure_server(server_host: &String) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> { + let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap(); + let cert_der = cert.serialize_der().unwrap(); + let priv_key = rustls::PrivateKey(cert.serialize_private_key_der()); + let cert_chain = vec![rustls::Certificate(cert_der.clone())]; + + let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?; + Arc::get_mut(&mut server_config.transport) + .unwrap() + .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S))); + + Ok((server_config, cert_der)) +} + +fn start_server( + mut commands: Commands, + runtime: Res<Runtime>, + config: Res<ServerConfigurationData>, +) { + // TODO Fix: handle unwraps + let server_adr_str = format!("{}:{}", config.local_bind_host, config.port); + info!("Starting server on: {} ...", server_adr_str); + + let server_addr: SocketAddr = server_adr_str + .parse() + .expect("Failed to parse server address"); + + // TODO Security: Server certificate + let (server_config, server_cert) = + configure_server(&config.host).expect("Failed to configure server"); + + // TODO Clean: Configure size + let (from_clients_sender, from_clients_receiver) = + mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (broadcast_channel_sender, _) = + broadcast::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + + let close_senders = Arc::new(Mutex::new(HashMap::new())); + commands.insert_resource(Server { + broadcast_sender: broadcast_channel_sender.clone(), + receiver: from_clients_receiver, + close_senders: close_senders.clone() + }); + + // Create server task + runtime.spawn(async move { + let mut client_gen_id: ClientId = 0; + let mut client_id_mappings = HashMap::new(); + + let (_endpoint, mut incoming) = + Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint"); + + // Start iterating over incoming connections/clients. + while let Some(conn) = incoming.next().await { + let mut new_connection: NewConnection = + conn.await.expect("Failed to handle incoming connection"); + let connection = new_connection.connection; + + // Attribute an id to this client + client_gen_id += 1; // TODO Fix: Better id generation/check + let client_id = client_gen_id; + client_id_mappings.insert(connection.stable_id(), client_id); + // TODO Clean: Raise a connection event to the sync side. + + info!( + "New connection from {}, client_id: {}, stable_id : {}", + connection.remote_address(), + client_id, + connection.stable_id() + ); + + // Create a close channel for this client + let (close_sender, mut close_receiver): ( + tokio::sync::broadcast::Sender<()>, + tokio::sync::broadcast::Receiver<()>, + ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + + // Create the broadcast stream + let send_stream = connection + .open_uni() + .await + .expect( format!("Failed to open broadcast send stream for client: {}", client_id).as_str()); + let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new()); + let mut broadcast_channel_receiver = broadcast_channel_sender.subscribe(); + let _network_broadcaster = tokio::spawn(async move { + tokio::select! { + _ = close_receiver.recv() => { + trace!("Broadcaster forced to disconnected for client: {}", client_id) + } + _ = async { + while let Ok(msg_bytes) = broadcast_channel_receiver.recv().await { + // TODO Perf: Batch frames for a send_all + // TODO Clean: Error handling + if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { + error!("Error while broadcasting to client {}: {}", client_id, err); + }; + } + } => {} + } + }); + + // Spawn a task to listen for stream opened from this client + let from_client_sender_clone = from_clients_sender.clone(); + let mut uni_receivers:JoinSet<()> = JoinSet::new(); + let mut close_receiver = close_sender.subscribe(); + let _client_receiver = tokio::spawn(async move { + tokio::select! { + _ = close_receiver.recv() => { + trace!("New Stream listener forced to disconnected for client: {}", client_id) + } + _ = async { + // For each new stream opened by the client + while let Some(Ok(recv)) = new_connection.uni_streams.next().await { + let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); + + // Spawn a task to receive data on this stream. + let from_client_sender = from_client_sender_clone.clone(); + uni_receivers.spawn(async move { + while let Some(Ok(msg_bytes)) = frame_recv.next().await { + from_client_sender + .send(ClientPayload { + client_id: client_id, + msg: msg_bytes.into(), + }) + .await + .unwrap();// TODO Fix: error event + } + trace!("Unidirectional stream receiver ended for client: {}", client_id) + }); + } + } => { + trace!("New Stream listener ended for client: {}", client_id) + } + } + uni_receivers.shutdown().await; + trace!("All unidirectional stream receivers cleaned for client: {}", client_id) + }); + + { + let mut close_senders = close_senders.lock().unwrap(); + close_senders.insert(client_id, close_sender); + } + } + }); +} |