diff options
author | gilles henaux <gill.henaux@gmail.com> | 2023-01-14 22:24:26 +0100 |
---|---|---|
committer | gilles henaux <gill.henaux@gmail.com> | 2023-01-14 22:24:26 +0100 |
commit | 1a29623e2a10183703907bffb9dd59d6c4f8dccc (patch) | |
tree | 994bb7cdd161fb6e6c44d688377a9c19eefe1db4 | |
parent | f11a0360a273ba7af83e92db332c5b9be2854d29 (diff) |
[server] Implement channels on the server
-rw-r--r-- | src/server.rs | 589 |
1 files changed, 385 insertions, 204 deletions
diff --git a/src/server.rs b/src/server.rs index 932401e..d73ded5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,26 +1,35 @@ -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; +use std::{ + collections::{HashMap, HashSet}, + net::SocketAddr, + sync::Arc, + time::Duration, +}; use bevy::prelude::*; use bytes::Bytes; -use futures::sink::SinkExt; use futures_util::StreamExt; -use quinn::{Endpoint as QuinnEndpoint, SendStream, ServerConfig}; -use quinn_proto::VarInt; +use quinn::{ConnectionError, Endpoint as QuinnEndpoint, ServerConfig}; use serde::Deserialize; use tokio::{ runtime, sync::{ broadcast::{self}, - mpsc::{self, error::TryRecvError}, + mpsc::{ + self, + error::{TryRecvError, TrySendError}, + }, }, task::JoinSet, }; -use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use crate::{ server::certificate::retrieve_certificate, shared::{ - channel::{ChannelAsyncMessage, ChannelSyncMessage}, + channel::{ + channels_task, get_channel_id_from_type, Channel, ChannelAsyncMessage, ChannelId, + ChannelSyncMessage, ChannelType, MultiChannelId, + }, AsyncRuntime, ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE, }, @@ -90,30 +99,76 @@ pub struct ClientPayload { } #[derive(Debug)] -pub(crate) enum InternalAsyncMessage { +pub(crate) enum ServerAsyncMessage { ClientConnected(ClientConnection), - ClientLostConnection(ClientId), + ClientConnectionClosed(ClientId, ConnectionError), } #[derive(Debug, Clone)] -pub(crate) enum InternalSyncMessage { +pub(crate) enum ServerSyncMessage { ClientConnectedAck(ClientId), } #[derive(Debug)] pub(crate) struct ClientConnection { client_id: ClientId, - sender: mpsc::Sender<Bytes>, + channels: HashMap<ChannelId, Channel>, close_sender: broadcast::Sender<()>, + + pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>, + pub(crate) from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>, +} + +impl ClientConnection { + /// Immediately prevents new messages from being sent on the channel and signal the channel to closes all its background tasks. + /// Before trully closing, the channel will wait for all buffered messages to be properly sent according to the channel type. + /// Can fail if the [ChannelId] is unknown, or if the channel is already closed. + pub(crate) fn close_channel(&mut self, channel_id: ChannelId) -> Result<(), QuinnetError> { + match self.channels.remove(&channel_id) { + Some(channel) => channel.close(), + None => Err(QuinnetError::UnknownChannel(channel_id)), + } + } + + pub(crate) fn create_channel( + &mut self, + channel_id: ChannelId, + ) -> Result<ChannelId, QuinnetError> { + let (bytes_to_channel_send, bytes_to_channel_recv) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (channel_close_send, channel_close_recv) = + mpsc::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + + match self + .to_channels_send + .try_send(ChannelSyncMessage::CreateChannel { + channel_id, + bytes_to_channel_recv, + channel_close_recv, + }) { + Ok(_) => { + let channel = Channel::new(bytes_to_channel_send, channel_close_send); + self.channels.insert(channel_id, channel); + Ok(channel_id) + } + Err(err) => match err { + TrySendError::Full(_) => Err(QuinnetError::FullQueue), + TrySendError::Closed(_) => Err(QuinnetError::InternalChannelClosed), + }, + } + } } pub struct Endpoint { clients: HashMap<ClientId, ClientConnection>, - payloads_recv: mpsc::Receiver<ClientPayload>, + channels: HashSet<ChannelId>, + default_channel: Option<ChannelId>, + last_gen_id: MultiChannelId, + payloads_from_clients_recv: mpsc::Receiver<ClientPayload>, close_sender: broadcast::Sender<()>, - pub(crate) from_async_server_recv: mpsc::Receiver<InternalAsyncMessage>, - pub(crate) to_async_server_send: broadcast::Sender<InternalSyncMessage>, + pub(crate) from_async_server_recv: mpsc::Receiver<ServerAsyncMessage>, + pub(crate) to_async_server_send: broadcast::Sender<ServerSyncMessage>, } impl Endpoint { @@ -139,13 +194,45 @@ impl Endpoint { } } + pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> { + match self.payloads_from_clients_recv.try_recv() { + Ok(msg) => Ok(Some(msg)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed), + }, + } + } + + pub fn try_receive_payload(&mut self) -> Option<ClientPayload> { + match self.receive_payload() { + Ok(payload) => payload, + Err(err) => { + error!("try_receive_payload: {}", err); + None + } + } + } + pub fn send_message<T: serde::Serialize>( &mut self, client_id: ClientId, message: T, ) -> Result<(), QuinnetError> { + match self.default_channel { + Some(channel) => self.send_message_on(client_id, channel, message), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn send_message_on<T: serde::Serialize>( + &mut self, + client_id: ClientId, + channel_id: ChannelId, + message: T, + ) -> Result<(), QuinnetError> { match bincode::serialize(&message) { - Ok(payload) => Ok(self.send_payload(client_id, payload)?), + Ok(payload) => Ok(self.send_payload_on(client_id, channel_id, payload)?), Err(_) => Err(QuinnetError::Serialization), } } @@ -157,15 +244,39 @@ impl Endpoint { } } + pub fn try_send_message_on<T: serde::Serialize>( + &mut self, + client_id: ClientId, + channel_id: ChannelId, + message: T, + ) { + match self.send_message_on(client_id, channel_id, message) { + Ok(_) => {} + Err(err) => error!("try_send_message: {}", err), + } + } + pub fn send_group_message<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>( &self, client_ids: I, message: T, ) -> Result<(), QuinnetError> { + match self.default_channel { + Some(channel) => self.send_group_message_on(client_ids, channel, message), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn send_group_message_on<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>( + &self, + client_ids: I, + channel_id: ChannelId, + message: T, + ) -> Result<(), QuinnetError> { match bincode::serialize(&message) { Ok(payload) => { for id in client_ids { - self.send_payload(*id, payload.clone())?; + self.send_payload_on(*id, channel_id, payload.clone())?; } Ok(()) } @@ -184,9 +295,32 @@ impl Endpoint { } } + pub fn try_send_group_message_on<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>( + &self, + client_ids: I, + channel_id: ChannelId, + message: T, + ) { + match self.send_group_message_on(client_ids, channel_id, message) { + Ok(_) => {} + Err(err) => error!("try_send_group_message: {}", err), + } + } + pub fn broadcast_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> { + match self.default_channel { + Some(channel) => self.broadcast_message_on(channel, message), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn broadcast_message_on<T: serde::Serialize>( + &self, + channel_id: ChannelId, + message: T, + ) -> Result<(), QuinnetError> { match bincode::serialize(&message) { - Ok(payload) => Ok(self.broadcast_payload(payload)?), + Ok(payload) => Ok(self.broadcast_payload_on(channel_id, payload)?), Err(_) => Err(QuinnetError::Serialization), } } @@ -198,19 +332,33 @@ impl Endpoint { } } + pub fn try_broadcast_message_on<T: serde::Serialize>(&self, message: T, channel_id: ChannelId) { + match self.broadcast_message_on(channel_id, message) { + Ok(_) => {} + Err(err) => error!("try_broadcast_message: {}", err), + } + } + pub fn broadcast_payload<T: Into<Bytes> + Clone>( &self, payload: T, ) -> Result<(), QuinnetError> { + match self.default_channel { + Some(channel) => self.broadcast_payload_on(channel, payload), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn broadcast_payload_on<T: Into<Bytes> + Clone>( + &self, + channel_id: ChannelId, + payload: T, + ) -> Result<(), QuinnetError> { + let payload = payload.into(); for (_, client_connection) in self.clients.iter() { - match client_connection.sender.try_send(payload.clone().into()) { - Ok(_) => {} - Err(err) => match err { - mpsc::error::TrySendError::Full(_) => return Err(QuinnetError::FullQueue), - mpsc::error::TrySendError::Closed(_) => { - return Err(QuinnetError::InternalChannelClosed) - } - }, + match client_connection.channels.get(&channel_id) { + Some(channel) => channel.send_payload(payload.clone())?, + None => return Err(QuinnetError::UnknownChannel(channel_id)), }; } Ok(()) @@ -223,20 +371,38 @@ impl Endpoint { } } + pub fn try_broadcast_payload_on<T: Into<Bytes> + Clone>( + &self, + channel_id: ChannelId, + payload: T, + ) { + match self.broadcast_payload_on(channel_id, payload) { + Ok(_) => {} + Err(err) => error!("try_broadcast_payload_on: {}", err), + } + } + pub fn send_payload<T: Into<Bytes>>( &self, client_id: ClientId, payload: T, ) -> Result<(), QuinnetError> { - if let Some(client) = self.clients.get(&client_id) { - match client.sender.try_send(payload.into()) { - Ok(_) => Ok(()), - Err(err) => match err { - mpsc::error::TrySendError::Full(_) => Err(QuinnetError::FullQueue), - mpsc::error::TrySendError::Closed(_) => { - Err(QuinnetError::InternalChannelClosed) - } - }, + match self.default_channel { + Some(channel) => self.send_payload_on(client_id, channel, payload), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn send_payload_on<T: Into<Bytes>>( + &self, + client_id: ClientId, + channel_id: ChannelId, + payload: T, + ) -> Result<(), QuinnetError> { + if let Some(client_connection) = self.clients.get(&client_id) { + match client_connection.channels.get(&channel_id) { + Some(channel) => channel.send_payload(payload), + None => return Err(QuinnetError::UnknownChannel(channel_id)), } } else { Err(QuinnetError::UnknownClient(client_id)) @@ -250,23 +416,15 @@ impl Endpoint { } } - pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> { - match self.payloads_recv.try_recv() { - Ok(msg) => Ok(Some(msg)), - Err(err) => match err { - TryRecvError::Empty => Ok(None), - TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed), - }, - } - } - - pub fn try_receive_payload(&mut self) -> Option<ClientPayload> { - match self.receive_payload() { - Ok(payload) => payload, - Err(err) => { - error!("try_receive_payload: {}", err); - None - } + pub fn try_send_payload_on<T: Into<Bytes>>( + &self, + client_id: ClientId, + channel_id: ChannelId, + payload: T, + ) { + match self.send_payload_on(client_id, channel_id, payload) { + Ok(_) => {} + Err(err) => error!("try_send_payload_on: {}", err), } } @@ -280,6 +438,16 @@ impl Endpoint { } } + pub fn try_disconnect_client(&mut self, client_id: ClientId) { + match self.disconnect_client(client_id) { + Ok(_) => (), + Err(err) => error!( + "Failed to properly disconnect client {}: {}", + client_id, err + ), + } + } + pub fn disconnect_all_clients(&mut self) -> Result<(), QuinnetError> { for client_id in self.clients.keys().cloned().collect::<Vec<ClientId>>() { self.disconnect_client(client_id)?; @@ -287,12 +455,77 @@ impl Endpoint { Ok(()) } - pub(crate) fn close_incoming_connections_handler(&mut self) -> Result<(), QuinnetError> { + pub fn open_channel(&mut self, channel_type: ChannelType) -> Result<ChannelId, QuinnetError> { + let channel_id = get_channel_id_from_type(channel_type, || { + self.last_gen_id += 1; + self.last_gen_id + }); + match self.channels.contains(&channel_id) { + true => Ok(channel_id), + false => self.create_channel(channel_id), + } + } + + pub fn close_channel(&mut self, channel_id: ChannelId) -> Result<(), QuinnetError> { + match self.channels.remove(&channel_id) { + true => { + if Some(channel_id) == self.default_channel { + self.default_channel = None; + } + for (_, connection) in self.clients.iter_mut() { + connection.close_channel(channel_id)?; + } + Ok(()) + } + false => Err(QuinnetError::UnknownChannel(channel_id)), + } + } + + fn open_default_channels(&mut self) -> Result<ChannelId, QuinnetError> { + self.open_channel(ChannelType::OrderedReliable)?; + self.open_channel(ChannelType::UnorderedReliable)?; + self.open_channel(ChannelType::Unreliable) + } + + fn create_channel(&mut self, channel_id: ChannelId) -> Result<ChannelId, QuinnetError> { + for (_, client_connection) in self.clients.iter_mut() { + client_connection.create_channel(channel_id)?; + } + self.channels.insert(channel_id); + if self.default_channel.is_none() { + self.default_channel = Some(channel_id); + } + Ok(channel_id) + } + + fn close_incoming_connections_handler(&mut self) -> Result<(), QuinnetError> { match self.close_sender.send(()) { Ok(_) => Ok(()), Err(_) => Err(QuinnetError::InternalChannelClosed), } } + + fn handle_connection(&mut self, mut connection: ClientConnection) -> Result<(), QuinnetError> { + match self.clients.contains_key(&connection.client_id) { + true => todo!(), + false => { + for channel_id in self.channels.iter() { + connection.create_channel(*channel_id)?; + } + + match self + .to_async_server_send + .send(ServerSyncMessage::ClientConnectedAck(connection.client_id)) + { + Ok(_) => { + self.clients.insert(connection.client_id, connection); + Ok(()) + } + Err(_) => Err(QuinnetError::InternalChannelClosed), + } + } + } + } } #[derive(Resource)] @@ -340,10 +573,9 @@ impl Server { let (payloads_from_clients_send, payloads_from_clients_recv) = mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE); let (to_sync_server_send, from_async_server_recv) = - mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + mpsc::channel::<ServerAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); let (to_async_server_send, from_sync_server_recv) = - broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - // Create a close channel for this endpoint + broadcast::channel::<ServerSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); let (endpoint_close_send, endpoint_close_recv) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); @@ -361,13 +593,20 @@ impl Server { .await; }); - self.endpoint = Some(Endpoint { - clients: HashMap::new(), - payloads_recv: payloads_from_clients_recv, - close_sender: endpoint_close_send, - from_async_server_recv, - to_async_server_send: to_async_server_send.clone(), - }); + { + let mut endpoint = Endpoint { + clients: HashMap::new(), + channels: HashSet::new(), + default_channel: None, + last_gen_id: 0, + payloads_from_clients_recv, + close_sender: endpoint_close_send, + from_async_server_recv, + to_async_server_send: to_async_server_send.clone(), + }; + endpoint.open_default_channels()?; + self.endpoint = Some(endpoint); + } Ok(server_cert) } @@ -394,9 +633,9 @@ impl Server { async fn endpoint_task( endpoint_config: ServerConfig, endpoint_adr: SocketAddr, - to_sync_server_send: mpsc::Sender<InternalAsyncMessage>, + to_sync_server_send: mpsc::Sender<ServerAsyncMessage>, mut endpoint_close_recv: broadcast::Receiver<()>, - mut from_sync_server_recv: broadcast::Receiver<InternalSyncMessage>, + from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>, payloads_from_clients_send: mpsc::Sender<ClientPayload>, ) { let mut client_gen_id: ClientId = 0; @@ -418,184 +657,110 @@ async fn endpoint_task( let client_id = client_gen_id; client_id_mappings.insert(connection.stable_id(), client_id); - handle_client_connection( - connection, - client_id, - &to_sync_server_send, - &mut from_sync_server_recv, - payloads_from_clients_send.clone(), - ) - .await; + { + let to_sync_server_send = to_sync_server_send.clone(); + let from_sync_server_recv = from_sync_server_recv.resubscribe(); + let payloads_from_clients_send = payloads_from_clients_send.clone(); + tokio::spawn(async move { + client_connection_task( + connection, + client_id, + to_sync_server_send, + from_sync_server_recv, + payloads_from_clients_send, + ) + .await + }); + } }, } - } } => {} } } -async fn handle_client_connection( +async fn client_connection_task( connection: quinn::Connection, client_id: ClientId, - to_sync_server_send: &mpsc::Sender<InternalAsyncMessage>, - from_sync_server_recv: &mut broadcast::Receiver<InternalSyncMessage>, + to_sync_server_send: mpsc::Sender<ServerAsyncMessage>, + mut from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>, payloads_from_clients_send: mpsc::Sender<ClientPayload>, ) { info!( - "New connection from {}, client_id: {}, stable_id : {}", + "New connection from {}, client_id: {}", connection.remote_address(), - client_id, - connection.stable_id() + client_id ); let (client_close_send, client_close_recv) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); - - // Create an ordered reliable send channel for this client - let (to_client_sender, to_client_receiver) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - - let _client_sender = { - let to_sync_server_send_clone = to_sync_server_send.clone(); - let close_sender_clone = client_close_send.clone(); - let connection_clone = connection.clone(); - tokio::spawn(async move { - client_sender_task( - client_id, - connection_clone, - to_client_receiver, - client_close_recv, - close_sender_clone, - to_sync_server_send_clone, - ) - .await - }) - }; + let (from_channels_send, from_channels_recv) = + mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_channels_send, to_channels_recv) = + mpsc::channel::<ChannelSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); // Signal the sync server of this new connection to_sync_server_send - .send(InternalAsyncMessage::ClientConnected(ClientConnection { + .send(ServerAsyncMessage::ClientConnected(ClientConnection { client_id: client_id, - sender: to_client_sender, + channels: HashMap::new(), close_sender: client_close_send.clone(), + from_channels_recv, + to_channels_send, })) .await .expect("Failed to signal connection to sync client"); // Wait for the sync server to acknowledge the connection before spawning reception tasks. - while let Ok(InternalSyncMessage::ClientConnectedAck(id)) = from_sync_server_recv.recv().await { + while let Ok(ServerSyncMessage::ClientConnectedAck(id)) = from_sync_server_recv.recv().await { if id == client_id { break; } } - let _client_close_wait = { + // Spawn a task to listen for the underlying connection being closed + { let conn = connection.clone(); - let close_sender = client_close_send.clone(); let to_sync_server = to_sync_server_send.clone(); tokio::spawn(async move { let conn_err = conn.closed().await; - info!("Client {} disconnected: {}", client_id, conn_err); - close_sender.send(()).ok(); + info!("Client {} connection closed: {}", client_id, conn_err); to_sync_server - .send(InternalAsyncMessage::ClientLostConnection(client_id)) + .send(ServerAsyncMessage::ClientConnectionClosed( + client_id, conn_err, + )) .await .expect("Failed to signal connection lost to sync server"); }); }; // Spawn a task to listen for streams opened by this client - let _client_receiver = tokio::spawn(async move { - client_receiver_task( - client_id, + { + let conn = connection.clone(); + let client_close_recv = client_close_recv.resubscribe(); + tokio::spawn(async move { + client_receiver_task( + client_id, + conn, + client_close_recv, + payloads_from_clients_send, + ) + .await + }); + } + + // Spawn a task to handle channels for this client + tokio::spawn(async move { + channels_task( connection, - client_close_send.subscribe(), - payloads_from_clients_send, + client_close_recv, + to_channels_recv, + from_channels_send, ) .await }); } -async fn client_sender_task( - client_id: ClientId, - connection: quinn::Connection, - mut to_client_receiver: tokio::sync::mpsc::Receiver<Bytes>, - mut client_close_recv: tokio::sync::broadcast::Receiver<()>, - close_sender: tokio::sync::broadcast::Sender<()>, - to_sync_server_send: mpsc::Sender<InternalAsyncMessage>, -) { - let send_stream = connection.open_uni().await.expect( - format!( - "Failed to open unidirectional send stream for client: {}", - client_id - ) - .as_str(), - ); - - let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new()); - - tokio::select! { - _ = client_close_recv.recv() => { - trace!("Unidirectional send stream forced to disconnected for client: {}", client_id) - } - _ = async { - while let Some(msg_bytes) = to_client_receiver.recv().await { - send_msg( - client_id, - &close_sender, - &to_sync_server_send, - &mut framed_send_stream, - msg_bytes, - ) - .await - } - } => {} - } - while let Ok(msg_bytes) = to_client_receiver.try_recv() { - if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { - error!("Error while sending to client {}: {}", client_id, err); - }; - } - if let Err(err) = framed_send_stream.flush().await { - error!( - "Error while flushing stream to client {}: {}", - client_id, err - ); - } - if let Err(err) = framed_send_stream.into_inner().finish().await { - error!( - "Failed to shutdown stream gracefully for client {}: {}", - client_id, err - ); - } - connection.close(VarInt::from_u32(0), "closed".as_bytes()); -} - -async fn send_msg( - client_id: ClientId, - close_sender: &tokio::sync::broadcast::Sender<()>, - to_sync_server: &mpsc::Sender<InternalAsyncMessage>, - framed_send_stream: &mut FramedWrite<SendStream, LengthDelimitedCodec>, - msg_bytes: Bytes, -) { - // TODO Perf: Batch frames for a send_all - if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { - error!("Error while sending to client {}: {}", client_id, err); - error!("Client {} seems disconnected, closing resources", client_id); - // Emit ClientLostConnection to properly update the server about this client state. - // Raise ClientLostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed. - to_sync_server - .send(InternalAsyncMessage::ClientLostConnection(client_id)) - .await - .expect("Failed to signal connection lost to sync server"); - if let Err(_) = close_sender.send(()) { - error!( - "Failed to close all client streams & resources for client {}", - client_id - ) - } - }; -} - async fn client_receiver_task( client_id: ClientId, connection: quinn::Connection, @@ -654,25 +819,41 @@ fn update_sync_server( if let Some(endpoint) = server.get_endpoint_mut() { while let Ok(message) = endpoint.from_async_server_recv.try_recv() { match message { - InternalAsyncMessage::ClientConnected(connection) => { + ServerAsyncMessage::ClientConnected(connection) => { let id = connection.client_id; - endpoint.clients.insert(id, connection); - endpoint - .to_async_server_send - .send(InternalSyncMessage::ClientConnectedAck(id)) - .unwrap(); - connection_events.send(ConnectionEvent { id: id }); + match endpoint.handle_connection(connection) { + Ok(_) => connection_events.send(ConnectionEvent { id }), + Err(_) => error!("Failed to handle connection of client {}", id), + }; } - InternalAsyncMessage::ClientLostConnection(client_id) => { - match endpoint.clients.remove(&client_id) { - Some(_) => { + ServerAsyncMessage::ClientConnectionClosed(client_id, _) => { + match endpoint.clients.contains_key(&client_id) { + true => { + endpoint.try_disconnect_client(client_id); connection_lost_events.send(ConnectionLostEvent { id: client_id }) } - None => (), + false => (), + } + } + } + } + + let mut lost_clients = HashSet::new(); + for (client_id, connection) in endpoint.clients.iter_mut() { + while let Ok(message) = connection.from_channels_recv.try_recv() { + match message { + ChannelAsyncMessage::LostConnection => { + if !lost_clients.contains(client_id) { + lost_clients.insert(*client_id); + connection_lost_events.send(ConnectionLostEvent { id: *client_id }) + } } } } } + for client_id in lost_clients { + endpoint.try_disconnect_client(client_id); + } } } |