diff options
-rw-r--r-- | src/server.rs | 209 |
1 files changed, 46 insertions, 163 deletions
diff --git a/src/server.rs b/src/server.rs index 0ac916b..6729728 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,8 +7,7 @@ use std::{ use bevy::prelude::*; use bytes::Bytes; -use futures_util::StreamExt; -use quinn::{ConnectionError, Endpoint as QuinnEndpoint, RecvStream, ServerConfig}; +use quinn::{ConnectionError, Endpoint as QuinnEndpoint, ServerConfig}; use serde::Deserialize; use tokio::{ runtime, @@ -20,14 +19,14 @@ use tokio::{ }, }, }; -use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use crate::{ server::certificate::retrieve_certificate, shared::{ channel::{ - channels_task, get_channel_id_from_type, Channel, ChannelAsyncMessage, ChannelId, - ChannelSyncMessage, ChannelType, MultiChannelId, + channels_task, get_channel_id_from_type, reliable_receiver_task, + unreliable_receiver_task, Channel, ChannelAsyncMessage, ChannelId, ChannelSyncMessage, + ChannelType, MultiChannelId, }, AsyncRuntime, ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE, @@ -98,30 +97,17 @@ impl ServerConfigurationData { } } -/// Represents a client message in its binary form -#[derive(Debug)] -pub struct ClientPayload { - /// Id of the client sending the message - client_id: ClientId, - /// Content of the message as bytes - msg: Bytes, -} - #[derive(Debug)] pub(crate) enum ServerAsyncMessage { ClientConnected(ClientConnection), ClientConnectionClosed(ClientId, ConnectionError), } -#[derive(Debug, Clone)] -pub(crate) enum ServerSyncMessage { - ClientConnectedAck(ClientId), -} - #[derive(Debug)] -pub(crate) struct ClientConnection { +pub struct ClientConnection { client_id: ClientId, channels: HashMap<ChannelId, Channel>, + bytes_from_client_recv: mpsc::Receiver<Bytes>, close_sender: broadcast::Sender<()>, pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>, @@ -173,28 +159,35 @@ pub struct Endpoint { channels: HashSet<ChannelId>, default_channel: Option<ChannelId>, last_gen_id: MultiChannelId, - payloads_from_clients_recv: mpsc::Receiver<ClientPayload>, close_sender: broadcast::Sender<()>, pub(crate) from_async_server_recv: mpsc::Receiver<ServerAsyncMessage>, - pub(crate) to_async_server_send: broadcast::Sender<ServerSyncMessage>, } impl Endpoint { - pub fn receive_message<T: serde::de::DeserializeOwned>( + /// Returns an iterator over all client ids + pub fn clients(&self) -> Vec<ClientId> { + self.clients.keys().cloned().collect() + } + + pub fn receive_message_from<T: serde::de::DeserializeOwned>( &mut self, - ) -> Result<Option<(T, ClientId)>, QuinnetError> { - match self.receive_payload()? { - Some(client_msg) => match bincode::deserialize(&client_msg.msg) { - Ok(msg) => Ok(Some((msg, client_msg.client_id))), + client_id: ClientId, + ) -> Result<Option<T>, QuinnetError> { + match self.receive_payload_from(client_id)? { + Some(payload) => match bincode::deserialize(&payload) { + Ok(msg) => Ok(Some(msg)), Err(_) => Err(QuinnetError::Deserialization), }, None => Ok(None), } } - pub fn try_receive_message<T: serde::de::DeserializeOwned>(&mut self) -> Option<(T, ClientId)> { - match self.receive_message() { + pub fn try_receive_message_from<T: serde::de::DeserializeOwned>( + &mut self, + client_id: ClientId, + ) -> Option<T> { + match self.receive_message_from(client_id) { Ok(message) => message, Err(err) => { error!("try_receive_message: {}", err); @@ -203,18 +196,24 @@ impl Endpoint { } } - pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> { - match self.payloads_from_clients_recv.try_recv() { - Ok(msg) => Ok(Some(msg)), - Err(err) => match err { - TryRecvError::Empty => Ok(None), - TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed), + pub fn receive_payload_from( + &mut self, + client_id: ClientId, + ) -> Result<Option<Bytes>, QuinnetError> { + match self.clients.get_mut(&client_id) { + Some(client) => match client.bytes_from_client_recv.try_recv() { + Ok(msg) => Ok(Some(msg)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed), + }, }, + None => Err(QuinnetError::UnknownClient(client_id)), } } - pub fn try_receive_payload(&mut self) -> Option<ClientPayload> { - match self.receive_payload() { + pub fn try_receive_payload_from(&mut self, client_id: ClientId) -> Option<Bytes> { + match self.receive_payload_from(client_id) { Ok(payload) => payload, Err(err) => { error!("try_receive_payload: {}", err); @@ -548,16 +547,8 @@ impl Endpoint { connection.create_channel(*channel_id)?; } - match self - .to_async_server_send - .send(ServerSyncMessage::ClientConnectedAck(connection.client_id)) - { - Ok(_) => { - self.clients.insert(connection.client_id, connection); - Ok(()) - } - Err(_) => Err(QuinnetError::InternalChannelClosed), - } + self.clients.insert(connection.client_id, connection); + Ok(()) } } } @@ -607,12 +598,8 @@ impl Server { .ok_or(QuinnetError::LockAcquisitionFailure)? .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S))); - let (payloads_from_clients_send, payloads_from_clients_recv) = - mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE); let (to_sync_server_send, from_async_server_recv) = mpsc::channel::<ServerAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - let (to_async_server_send, from_sync_server_recv) = - broadcast::channel::<ServerSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); let (endpoint_close_send, endpoint_close_recv) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); @@ -624,8 +611,6 @@ impl Server { server_addr, to_sync_server_send.clone(), endpoint_close_recv, - from_sync_server_recv, - payloads_from_clients_send.clone(), ) .await; }); @@ -635,10 +620,8 @@ impl Server { channels: HashSet::new(), default_channel: None, last_gen_id: 0, - payloads_from_clients_recv, close_sender: endpoint_close_send, from_async_server_recv, - to_async_server_send: to_async_server_send.clone(), }; let ordered_reliable_id = endpoint.open_default_channels()?; self.endpoint = Some(endpoint); @@ -673,8 +656,6 @@ async fn endpoint_task( endpoint_adr: SocketAddr, to_sync_server_send: mpsc::Sender<ServerAsyncMessage>, mut endpoint_close_recv: broadcast::Receiver<()>, - from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>, - payloads_from_clients_send: mpsc::Sender<ClientPayload>, ) { let mut client_gen_id: ClientId = 0; let mut client_id_mappings = HashMap::new(); @@ -697,15 +678,11 @@ async fn endpoint_task( { let to_sync_server_send = to_sync_server_send.clone(); - let from_sync_server_recv = from_sync_server_recv.resubscribe(); - let payloads_from_clients_send = payloads_from_clients_send.clone(); tokio::spawn(async move { client_connection_task( connection, client_id, - to_sync_server_send, - from_sync_server_recv, - payloads_from_clients_send, + to_sync_server_send ) .await }); @@ -721,8 +698,6 @@ async fn client_connection_task( connection: quinn::Connection, client_id: ClientId, to_sync_server_send: mpsc::Sender<ServerAsyncMessage>, - mut from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>, - payloads_from_clients_send: mpsc::Sender<ClientPayload>, ) { info!( "New connection from {}, client_id: {}", @@ -732,6 +707,8 @@ async fn client_connection_task( let (client_close_send, client_close_recv) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + let (bytes_from_client_send, bytes_from_client_recv) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); let (from_channels_send, from_channels_recv) = mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); let (to_channels_send, to_channels_recv) = @@ -742,6 +719,7 @@ async fn client_connection_task( .send(ServerAsyncMessage::ClientConnected(ClientConnection { client_id: client_id, channels: HashMap::new(), + bytes_from_client_recv, close_sender: client_close_send.clone(), from_channels_recv, to_channels_send, @@ -749,13 +727,6 @@ async fn client_connection_task( .await .expect("Failed to signal connection to sync client"); - // Wait for the sync server to acknowledge the connection before spawning reception tasks. - while let Ok(ServerSyncMessage::ClientConnectedAck(id)) = from_sync_server_recv.recv().await { - if id == client_id { - break; - } - } - // Spawn a task to listen for the underlying connection being closed { let conn = connection.clone(); @@ -779,13 +750,13 @@ async fn client_connection_task( { let connection_handle = connection.clone(); let client_close_recv = client_close_recv.resubscribe(); - let payloads_incoming_send = payloads_from_clients_send.clone(); + let bytes_incoming_send = bytes_from_client_send.clone(); tokio::spawn(async move { reliable_receiver_task( client_id, connection_handle, client_close_recv, - payloads_incoming_send, + bytes_incoming_send, ) .await }); @@ -795,13 +766,13 @@ async fn client_connection_task( { let connection_handle = connection.clone(); let client_close_recv = client_close_recv.resubscribe(); - let payloads_incoming_send = payloads_from_clients_send.clone(); + let bytes_incoming_send = bytes_from_client_send.clone(); tokio::spawn(async move { unreliable_receiver_task( client_id, connection_handle, client_close_recv, - payloads_incoming_send, + bytes_incoming_send, ) .await }); @@ -819,94 +790,6 @@ async fn client_connection_task( }); } -async fn uni_receiver_task( - client_id: ClientId, - mut close_recv: broadcast::Receiver<()>, - recv: RecvStream, - payloads_from_clients_send: mpsc::Sender<ClientPayload>, -) { - tokio::select! { - _ = close_recv.recv() => { - trace!("Listener of a Unidirectional Receiving Streams received a close signal") - } - _ = async { - let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - while let Some(Ok(msg_bytes)) = frame_recv.next().await { - // TODO Clean: error handling - payloads_from_clients_send - .send(ClientPayload { - client_id: client_id, - msg: msg_bytes.into(), - }) - .await - .unwrap(); - } - } => {} - }; -} - -async fn reliable_receiver_task( - client_id: ClientId, - connection: quinn::Connection, - mut close_recv: tokio::sync::broadcast::Receiver<()>, - payloads_from_clients_send: mpsc::Sender<ClientPayload>, -) { - let close_recv_clone = close_recv.resubscribe(); - tokio::select! { - _ = close_recv.recv() => { - trace!("Listener for new Unidirectional Receiving Streams received a close signal for client: {}", client_id) - } - _ = async { - while let Ok(recv) = connection.accept_uni().await { - let payloads_from_clients_send = payloads_from_clients_send.clone(); - let close_recv_clone = close_recv_clone.resubscribe(); - tokio::spawn(async move { - uni_receiver_task( - client_id, - close_recv_clone, - recv, - payloads_from_clients_send - ).await; - }); - } - } => { - trace!("New Stream listener ended for client: {}", client_id) - } - } - trace!( - "All unidirectional stream receivers cleaned for client: {}", - client_id - ) -} - -async fn unreliable_receiver_task( - client_id: ClientId, - connection: quinn::Connection, - mut close_recv: broadcast::Receiver<()>, - payloads_incoming_send: mpsc::Sender<ClientPayload>, -) { - tokio::select! { - _ = close_recv.recv() => { - trace!("Listener for unreliable datagrams received a close signal for client: {}", - client_id) - } - _ = async { - while let Ok(msg_bytes) = connection.read_datagram().await { - // TODO Clean: error handling - payloads_incoming_send.send(ClientPayload { - client_id: client_id, - msg: msg_bytes.into(), - }) - .await - .unwrap(); - } - } => { - trace!("Listener for unreliable datagrams ended for client: {}", - client_id) - } - }; -} - fn create_server(mut commands: Commands, runtime: Res<AsyncRuntime>) { commands.insert_resource(Server { endpoint: None, |