aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgilles henaux <gill.henaux@gmail.com>2023-01-14 22:24:26 +0100
committergilles henaux <gill.henaux@gmail.com>2023-01-14 22:24:26 +0100
commit1a29623e2a10183703907bffb9dd59d6c4f8dccc (patch)
tree994bb7cdd161fb6e6c44d688377a9c19eefe1db4
parentf11a0360a273ba7af83e92db332c5b9be2854d29 (diff)
[server] Implement channels on the server
-rw-r--r--src/server.rs589
1 files changed, 385 insertions, 204 deletions
diff --git a/src/server.rs b/src/server.rs
index 932401e..d73ded5 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,26 +1,35 @@
-use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
+use std::{
+ collections::{HashMap, HashSet},
+ net::SocketAddr,
+ sync::Arc,
+ time::Duration,
+};
use bevy::prelude::*;
use bytes::Bytes;
-use futures::sink::SinkExt;
use futures_util::StreamExt;
-use quinn::{Endpoint as QuinnEndpoint, SendStream, ServerConfig};
-use quinn_proto::VarInt;
+use quinn::{ConnectionError, Endpoint as QuinnEndpoint, ServerConfig};
use serde::Deserialize;
use tokio::{
runtime,
sync::{
broadcast::{self},
- mpsc::{self, error::TryRecvError},
+ mpsc::{
+ self,
+ error::{TryRecvError, TrySendError},
+ },
},
task::JoinSet,
};
-use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
+use tokio_util::codec::{FramedRead, LengthDelimitedCodec};
use crate::{
server::certificate::retrieve_certificate,
shared::{
- channel::{ChannelAsyncMessage, ChannelSyncMessage},
+ channel::{
+ channels_task, get_channel_id_from_type, Channel, ChannelAsyncMessage, ChannelId,
+ ChannelSyncMessage, ChannelType, MultiChannelId,
+ },
AsyncRuntime, ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S,
DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE,
},
@@ -90,30 +99,76 @@ pub struct ClientPayload {
}
#[derive(Debug)]
-pub(crate) enum InternalAsyncMessage {
+pub(crate) enum ServerAsyncMessage {
ClientConnected(ClientConnection),
- ClientLostConnection(ClientId),
+ ClientConnectionClosed(ClientId, ConnectionError),
}
#[derive(Debug, Clone)]
-pub(crate) enum InternalSyncMessage {
+pub(crate) enum ServerSyncMessage {
ClientConnectedAck(ClientId),
}
#[derive(Debug)]
pub(crate) struct ClientConnection {
client_id: ClientId,
- sender: mpsc::Sender<Bytes>,
+ channels: HashMap<ChannelId, Channel>,
close_sender: broadcast::Sender<()>,
+
+ pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>,
+ pub(crate) from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>,
+}
+
+impl ClientConnection {
+ /// Immediately prevents new messages from being sent on the channel and signal the channel to closes all its background tasks.
+ /// Before trully closing, the channel will wait for all buffered messages to be properly sent according to the channel type.
+ /// Can fail if the [ChannelId] is unknown, or if the channel is already closed.
+ pub(crate) fn close_channel(&mut self, channel_id: ChannelId) -> Result<(), QuinnetError> {
+ match self.channels.remove(&channel_id) {
+ Some(channel) => channel.close(),
+ None => Err(QuinnetError::UnknownChannel(channel_id)),
+ }
+ }
+
+ pub(crate) fn create_channel(
+ &mut self,
+ channel_id: ChannelId,
+ ) -> Result<ChannelId, QuinnetError> {
+ let (bytes_to_channel_send, bytes_to_channel_recv) =
+ mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
+ let (channel_close_send, channel_close_recv) =
+ mpsc::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
+
+ match self
+ .to_channels_send
+ .try_send(ChannelSyncMessage::CreateChannel {
+ channel_id,
+ bytes_to_channel_recv,
+ channel_close_recv,
+ }) {
+ Ok(_) => {
+ let channel = Channel::new(bytes_to_channel_send, channel_close_send);
+ self.channels.insert(channel_id, channel);
+ Ok(channel_id)
+ }
+ Err(err) => match err {
+ TrySendError::Full(_) => Err(QuinnetError::FullQueue),
+ TrySendError::Closed(_) => Err(QuinnetError::InternalChannelClosed),
+ },
+ }
+ }
}
pub struct Endpoint {
clients: HashMap<ClientId, ClientConnection>,
- payloads_recv: mpsc::Receiver<ClientPayload>,
+ channels: HashSet<ChannelId>,
+ default_channel: Option<ChannelId>,
+ last_gen_id: MultiChannelId,
+ payloads_from_clients_recv: mpsc::Receiver<ClientPayload>,
close_sender: broadcast::Sender<()>,
- pub(crate) from_async_server_recv: mpsc::Receiver<InternalAsyncMessage>,
- pub(crate) to_async_server_send: broadcast::Sender<InternalSyncMessage>,
+ pub(crate) from_async_server_recv: mpsc::Receiver<ServerAsyncMessage>,
+ pub(crate) to_async_server_send: broadcast::Sender<ServerSyncMessage>,
}
impl Endpoint {
@@ -139,13 +194,45 @@ impl Endpoint {
}
}
+ pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> {
+ match self.payloads_from_clients_recv.try_recv() {
+ Ok(msg) => Ok(Some(msg)),
+ Err(err) => match err {
+ TryRecvError::Empty => Ok(None),
+ TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed),
+ },
+ }
+ }
+
+ pub fn try_receive_payload(&mut self) -> Option<ClientPayload> {
+ match self.receive_payload() {
+ Ok(payload) => payload,
+ Err(err) => {
+ error!("try_receive_payload: {}", err);
+ None
+ }
+ }
+ }
+
pub fn send_message<T: serde::Serialize>(
&mut self,
client_id: ClientId,
message: T,
) -> Result<(), QuinnetError> {
+ match self.default_channel {
+ Some(channel) => self.send_message_on(client_id, channel, message),
+ None => Err(QuinnetError::NoDefaultChannel),
+ }
+ }
+
+ pub fn send_message_on<T: serde::Serialize>(
+ &mut self,
+ client_id: ClientId,
+ channel_id: ChannelId,
+ message: T,
+ ) -> Result<(), QuinnetError> {
match bincode::serialize(&message) {
- Ok(payload) => Ok(self.send_payload(client_id, payload)?),
+ Ok(payload) => Ok(self.send_payload_on(client_id, channel_id, payload)?),
Err(_) => Err(QuinnetError::Serialization),
}
}
@@ -157,15 +244,39 @@ impl Endpoint {
}
}
+ pub fn try_send_message_on<T: serde::Serialize>(
+ &mut self,
+ client_id: ClientId,
+ channel_id: ChannelId,
+ message: T,
+ ) {
+ match self.send_message_on(client_id, channel_id, message) {
+ Ok(_) => {}
+ Err(err) => error!("try_send_message: {}", err),
+ }
+ }
+
pub fn send_group_message<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>(
&self,
client_ids: I,
message: T,
) -> Result<(), QuinnetError> {
+ match self.default_channel {
+ Some(channel) => self.send_group_message_on(client_ids, channel, message),
+ None => Err(QuinnetError::NoDefaultChannel),
+ }
+ }
+
+ pub fn send_group_message_on<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>(
+ &self,
+ client_ids: I,
+ channel_id: ChannelId,
+ message: T,
+ ) -> Result<(), QuinnetError> {
match bincode::serialize(&message) {
Ok(payload) => {
for id in client_ids {
- self.send_payload(*id, payload.clone())?;
+ self.send_payload_on(*id, channel_id, payload.clone())?;
}
Ok(())
}
@@ -184,9 +295,32 @@ impl Endpoint {
}
}
+ pub fn try_send_group_message_on<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>(
+ &self,
+ client_ids: I,
+ channel_id: ChannelId,
+ message: T,
+ ) {
+ match self.send_group_message_on(client_ids, channel_id, message) {
+ Ok(_) => {}
+ Err(err) => error!("try_send_group_message: {}", err),
+ }
+ }
+
pub fn broadcast_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> {
+ match self.default_channel {
+ Some(channel) => self.broadcast_message_on(channel, message),
+ None => Err(QuinnetError::NoDefaultChannel),
+ }
+ }
+
+ pub fn broadcast_message_on<T: serde::Serialize>(
+ &self,
+ channel_id: ChannelId,
+ message: T,
+ ) -> Result<(), QuinnetError> {
match bincode::serialize(&message) {
- Ok(payload) => Ok(self.broadcast_payload(payload)?),
+ Ok(payload) => Ok(self.broadcast_payload_on(channel_id, payload)?),
Err(_) => Err(QuinnetError::Serialization),
}
}
@@ -198,19 +332,33 @@ impl Endpoint {
}
}
+ pub fn try_broadcast_message_on<T: serde::Serialize>(&self, message: T, channel_id: ChannelId) {
+ match self.broadcast_message_on(channel_id, message) {
+ Ok(_) => {}
+ Err(err) => error!("try_broadcast_message: {}", err),
+ }
+ }
+
pub fn broadcast_payload<T: Into<Bytes> + Clone>(
&self,
payload: T,
) -> Result<(), QuinnetError> {
+ match self.default_channel {
+ Some(channel) => self.broadcast_payload_on(channel, payload),
+ None => Err(QuinnetError::NoDefaultChannel),
+ }
+ }
+
+ pub fn broadcast_payload_on<T: Into<Bytes> + Clone>(
+ &self,
+ channel_id: ChannelId,
+ payload: T,
+ ) -> Result<(), QuinnetError> {
+ let payload = payload.into();
for (_, client_connection) in self.clients.iter() {
- match client_connection.sender.try_send(payload.clone().into()) {
- Ok(_) => {}
- Err(err) => match err {
- mpsc::error::TrySendError::Full(_) => return Err(QuinnetError::FullQueue),
- mpsc::error::TrySendError::Closed(_) => {
- return Err(QuinnetError::InternalChannelClosed)
- }
- },
+ match client_connection.channels.get(&channel_id) {
+ Some(channel) => channel.send_payload(payload.clone())?,
+ None => return Err(QuinnetError::UnknownChannel(channel_id)),
};
}
Ok(())
@@ -223,20 +371,38 @@ impl Endpoint {
}
}
+ pub fn try_broadcast_payload_on<T: Into<Bytes> + Clone>(
+ &self,
+ channel_id: ChannelId,
+ payload: T,
+ ) {
+ match self.broadcast_payload_on(channel_id, payload) {
+ Ok(_) => {}
+ Err(err) => error!("try_broadcast_payload_on: {}", err),
+ }
+ }
+
pub fn send_payload<T: Into<Bytes>>(
&self,
client_id: ClientId,
payload: T,
) -> Result<(), QuinnetError> {
- if let Some(client) = self.clients.get(&client_id) {
- match client.sender.try_send(payload.into()) {
- Ok(_) => Ok(()),
- Err(err) => match err {
- mpsc::error::TrySendError::Full(_) => Err(QuinnetError::FullQueue),
- mpsc::error::TrySendError::Closed(_) => {
- Err(QuinnetError::InternalChannelClosed)
- }
- },
+ match self.default_channel {
+ Some(channel) => self.send_payload_on(client_id, channel, payload),
+ None => Err(QuinnetError::NoDefaultChannel),
+ }
+ }
+
+ pub fn send_payload_on<T: Into<Bytes>>(
+ &self,
+ client_id: ClientId,
+ channel_id: ChannelId,
+ payload: T,
+ ) -> Result<(), QuinnetError> {
+ if let Some(client_connection) = self.clients.get(&client_id) {
+ match client_connection.channels.get(&channel_id) {
+ Some(channel) => channel.send_payload(payload),
+ None => return Err(QuinnetError::UnknownChannel(channel_id)),
}
} else {
Err(QuinnetError::UnknownClient(client_id))
@@ -250,23 +416,15 @@ impl Endpoint {
}
}
- pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> {
- match self.payloads_recv.try_recv() {
- Ok(msg) => Ok(Some(msg)),
- Err(err) => match err {
- TryRecvError::Empty => Ok(None),
- TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed),
- },
- }
- }
-
- pub fn try_receive_payload(&mut self) -> Option<ClientPayload> {
- match self.receive_payload() {
- Ok(payload) => payload,
- Err(err) => {
- error!("try_receive_payload: {}", err);
- None
- }
+ pub fn try_send_payload_on<T: Into<Bytes>>(
+ &self,
+ client_id: ClientId,
+ channel_id: ChannelId,
+ payload: T,
+ ) {
+ match self.send_payload_on(client_id, channel_id, payload) {
+ Ok(_) => {}
+ Err(err) => error!("try_send_payload_on: {}", err),
}
}
@@ -280,6 +438,16 @@ impl Endpoint {
}
}
+ pub fn try_disconnect_client(&mut self, client_id: ClientId) {
+ match self.disconnect_client(client_id) {
+ Ok(_) => (),
+ Err(err) => error!(
+ "Failed to properly disconnect client {}: {}",
+ client_id, err
+ ),
+ }
+ }
+
pub fn disconnect_all_clients(&mut self) -> Result<(), QuinnetError> {
for client_id in self.clients.keys().cloned().collect::<Vec<ClientId>>() {
self.disconnect_client(client_id)?;
@@ -287,12 +455,77 @@ impl Endpoint {
Ok(())
}
- pub(crate) fn close_incoming_connections_handler(&mut self) -> Result<(), QuinnetError> {
+ pub fn open_channel(&mut self, channel_type: ChannelType) -> Result<ChannelId, QuinnetError> {
+ let channel_id = get_channel_id_from_type(channel_type, || {
+ self.last_gen_id += 1;
+ self.last_gen_id
+ });
+ match self.channels.contains(&channel_id) {
+ true => Ok(channel_id),
+ false => self.create_channel(channel_id),
+ }
+ }
+
+ pub fn close_channel(&mut self, channel_id: ChannelId) -> Result<(), QuinnetError> {
+ match self.channels.remove(&channel_id) {
+ true => {
+ if Some(channel_id) == self.default_channel {
+ self.default_channel = None;
+ }
+ for (_, connection) in self.clients.iter_mut() {
+ connection.close_channel(channel_id)?;
+ }
+ Ok(())
+ }
+ false => Err(QuinnetError::UnknownChannel(channel_id)),
+ }
+ }
+
+ fn open_default_channels(&mut self) -> Result<ChannelId, QuinnetError> {
+ self.open_channel(ChannelType::OrderedReliable)?;
+ self.open_channel(ChannelType::UnorderedReliable)?;
+ self.open_channel(ChannelType::Unreliable)
+ }
+
+ fn create_channel(&mut self, channel_id: ChannelId) -> Result<ChannelId, QuinnetError> {
+ for (_, client_connection) in self.clients.iter_mut() {
+ client_connection.create_channel(channel_id)?;
+ }
+ self.channels.insert(channel_id);
+ if self.default_channel.is_none() {
+ self.default_channel = Some(channel_id);
+ }
+ Ok(channel_id)
+ }
+
+ fn close_incoming_connections_handler(&mut self) -> Result<(), QuinnetError> {
match self.close_sender.send(()) {
Ok(_) => Ok(()),
Err(_) => Err(QuinnetError::InternalChannelClosed),
}
}
+
+ fn handle_connection(&mut self, mut connection: ClientConnection) -> Result<(), QuinnetError> {
+ match self.clients.contains_key(&connection.client_id) {
+ true => todo!(),
+ false => {
+ for channel_id in self.channels.iter() {
+ connection.create_channel(*channel_id)?;
+ }
+
+ match self
+ .to_async_server_send
+ .send(ServerSyncMessage::ClientConnectedAck(connection.client_id))
+ {
+ Ok(_) => {
+ self.clients.insert(connection.client_id, connection);
+ Ok(())
+ }
+ Err(_) => Err(QuinnetError::InternalChannelClosed),
+ }
+ }
+ }
+ }
}
#[derive(Resource)]
@@ -340,10 +573,9 @@ impl Server {
let (payloads_from_clients_send, payloads_from_clients_recv) =
mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
let (to_sync_server_send, from_async_server_recv) =
- mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
+ mpsc::channel::<ServerAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
let (to_async_server_send, from_sync_server_recv) =
- broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
- // Create a close channel for this endpoint
+ broadcast::channel::<ServerSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
let (endpoint_close_send, endpoint_close_recv) =
broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
@@ -361,13 +593,20 @@ impl Server {
.await;
});
- self.endpoint = Some(Endpoint {
- clients: HashMap::new(),
- payloads_recv: payloads_from_clients_recv,
- close_sender: endpoint_close_send,
- from_async_server_recv,
- to_async_server_send: to_async_server_send.clone(),
- });
+ {
+ let mut endpoint = Endpoint {
+ clients: HashMap::new(),
+ channels: HashSet::new(),
+ default_channel: None,
+ last_gen_id: 0,
+ payloads_from_clients_recv,
+ close_sender: endpoint_close_send,
+ from_async_server_recv,
+ to_async_server_send: to_async_server_send.clone(),
+ };
+ endpoint.open_default_channels()?;
+ self.endpoint = Some(endpoint);
+ }
Ok(server_cert)
}
@@ -394,9 +633,9 @@ impl Server {
async fn endpoint_task(
endpoint_config: ServerConfig,
endpoint_adr: SocketAddr,
- to_sync_server_send: mpsc::Sender<InternalAsyncMessage>,
+ to_sync_server_send: mpsc::Sender<ServerAsyncMessage>,
mut endpoint_close_recv: broadcast::Receiver<()>,
- mut from_sync_server_recv: broadcast::Receiver<InternalSyncMessage>,
+ from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>,
payloads_from_clients_send: mpsc::Sender<ClientPayload>,
) {
let mut client_gen_id: ClientId = 0;
@@ -418,184 +657,110 @@ async fn endpoint_task(
let client_id = client_gen_id;
client_id_mappings.insert(connection.stable_id(), client_id);
- handle_client_connection(
- connection,
- client_id,
- &to_sync_server_send,
- &mut from_sync_server_recv,
- payloads_from_clients_send.clone(),
- )
- .await;
+ {
+ let to_sync_server_send = to_sync_server_send.clone();
+ let from_sync_server_recv = from_sync_server_recv.resubscribe();
+ let payloads_from_clients_send = payloads_from_clients_send.clone();
+ tokio::spawn(async move {
+ client_connection_task(
+ connection,
+ client_id,
+ to_sync_server_send,
+ from_sync_server_recv,
+ payloads_from_clients_send,
+ )
+ .await
+ });
+ }
},
}
-
}
} => {}
}
}
-async fn handle_client_connection(
+async fn client_connection_task(
connection: quinn::Connection,
client_id: ClientId,
- to_sync_server_send: &mpsc::Sender<InternalAsyncMessage>,
- from_sync_server_recv: &mut broadcast::Receiver<InternalSyncMessage>,
+ to_sync_server_send: mpsc::Sender<ServerAsyncMessage>,
+ mut from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>,
payloads_from_clients_send: mpsc::Sender<ClientPayload>,
) {
info!(
- "New connection from {}, client_id: {}, stable_id : {}",
+ "New connection from {}, client_id: {}",
connection.remote_address(),
- client_id,
- connection.stable_id()
+ client_id
);
let (client_close_send, client_close_recv) =
broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
-
- // Create an ordered reliable send channel for this client
- let (to_client_sender, to_client_receiver) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
-
- let _client_sender = {
- let to_sync_server_send_clone = to_sync_server_send.clone();
- let close_sender_clone = client_close_send.clone();
- let connection_clone = connection.clone();
- tokio::spawn(async move {
- client_sender_task(
- client_id,
- connection_clone,
- to_client_receiver,
- client_close_recv,
- close_sender_clone,
- to_sync_server_send_clone,
- )
- .await
- })
- };
+ let (from_channels_send, from_channels_recv) =
+ mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
+ let (to_channels_send, to_channels_recv) =
+ mpsc::channel::<ChannelSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
// Signal the sync server of this new connection
to_sync_server_send
- .send(InternalAsyncMessage::ClientConnected(ClientConnection {
+ .send(ServerAsyncMessage::ClientConnected(ClientConnection {
client_id: client_id,
- sender: to_client_sender,
+ channels: HashMap::new(),
close_sender: client_close_send.clone(),
+ from_channels_recv,
+ to_channels_send,
}))
.await
.expect("Failed to signal connection to sync client");
// Wait for the sync server to acknowledge the connection before spawning reception tasks.
- while let Ok(InternalSyncMessage::ClientConnectedAck(id)) = from_sync_server_recv.recv().await {
+ while let Ok(ServerSyncMessage::ClientConnectedAck(id)) = from_sync_server_recv.recv().await {
if id == client_id {
break;
}
}
- let _client_close_wait = {
+ // Spawn a task to listen for the underlying connection being closed
+ {
let conn = connection.clone();
- let close_sender = client_close_send.clone();
let to_sync_server = to_sync_server_send.clone();
tokio::spawn(async move {
let conn_err = conn.closed().await;
- info!("Client {} disconnected: {}", client_id, conn_err);
- close_sender.send(()).ok();
+ info!("Client {} connection closed: {}", client_id, conn_err);
to_sync_server
- .send(InternalAsyncMessage::ClientLostConnection(client_id))
+ .send(ServerAsyncMessage::ClientConnectionClosed(
+ client_id, conn_err,
+ ))
.await
.expect("Failed to signal connection lost to sync server");
});
};
// Spawn a task to listen for streams opened by this client
- let _client_receiver = tokio::spawn(async move {
- client_receiver_task(
- client_id,
+ {
+ let conn = connection.clone();
+ let client_close_recv = client_close_recv.resubscribe();
+ tokio::spawn(async move {
+ client_receiver_task(
+ client_id,
+ conn,
+ client_close_recv,
+ payloads_from_clients_send,
+ )
+ .await
+ });
+ }
+
+ // Spawn a task to handle channels for this client
+ tokio::spawn(async move {
+ channels_task(
connection,
- client_close_send.subscribe(),
- payloads_from_clients_send,
+ client_close_recv,
+ to_channels_recv,
+ from_channels_send,
)
.await
});
}
-async fn client_sender_task(
- client_id: ClientId,
- connection: quinn::Connection,
- mut to_client_receiver: tokio::sync::mpsc::Receiver<Bytes>,
- mut client_close_recv: tokio::sync::broadcast::Receiver<()>,
- close_sender: tokio::sync::broadcast::Sender<()>,
- to_sync_server_send: mpsc::Sender<InternalAsyncMessage>,
-) {
- let send_stream = connection.open_uni().await.expect(
- format!(
- "Failed to open unidirectional send stream for client: {}",
- client_id
- )
- .as_str(),
- );
-
- let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
-
- tokio::select! {
- _ = client_close_recv.recv() => {
- trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
- }
- _ = async {
- while let Some(msg_bytes) = to_client_receiver.recv().await {
- send_msg(
- client_id,
- &close_sender,
- &to_sync_server_send,
- &mut framed_send_stream,
- msg_bytes,
- )
- .await
- }
- } => {}
- }
- while let Ok(msg_bytes) = to_client_receiver.try_recv() {
- if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
- error!("Error while sending to client {}: {}", client_id, err);
- };
- }
- if let Err(err) = framed_send_stream.flush().await {
- error!(
- "Error while flushing stream to client {}: {}",
- client_id, err
- );
- }
- if let Err(err) = framed_send_stream.into_inner().finish().await {
- error!(
- "Failed to shutdown stream gracefully for client {}: {}",
- client_id, err
- );
- }
- connection.close(VarInt::from_u32(0), "closed".as_bytes());
-}
-
-async fn send_msg(
- client_id: ClientId,
- close_sender: &tokio::sync::broadcast::Sender<()>,
- to_sync_server: &mpsc::Sender<InternalAsyncMessage>,
- framed_send_stream: &mut FramedWrite<SendStream, LengthDelimitedCodec>,
- msg_bytes: Bytes,
-) {
- // TODO Perf: Batch frames for a send_all
- if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
- error!("Error while sending to client {}: {}", client_id, err);
- error!("Client {} seems disconnected, closing resources", client_id);
- // Emit ClientLostConnection to properly update the server about this client state.
- // Raise ClientLostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed.
- to_sync_server
- .send(InternalAsyncMessage::ClientLostConnection(client_id))
- .await
- .expect("Failed to signal connection lost to sync server");
- if let Err(_) = close_sender.send(()) {
- error!(
- "Failed to close all client streams & resources for client {}",
- client_id
- )
- }
- };
-}
-
async fn client_receiver_task(
client_id: ClientId,
connection: quinn::Connection,
@@ -654,25 +819,41 @@ fn update_sync_server(
if let Some(endpoint) = server.get_endpoint_mut() {
while let Ok(message) = endpoint.from_async_server_recv.try_recv() {
match message {
- InternalAsyncMessage::ClientConnected(connection) => {
+ ServerAsyncMessage::ClientConnected(connection) => {
let id = connection.client_id;
- endpoint.clients.insert(id, connection);
- endpoint
- .to_async_server_send
- .send(InternalSyncMessage::ClientConnectedAck(id))
- .unwrap();
- connection_events.send(ConnectionEvent { id: id });
+ match endpoint.handle_connection(connection) {
+ Ok(_) => connection_events.send(ConnectionEvent { id }),
+ Err(_) => error!("Failed to handle connection of client {}", id),
+ };
}
- InternalAsyncMessage::ClientLostConnection(client_id) => {
- match endpoint.clients.remove(&client_id) {
- Some(_) => {
+ ServerAsyncMessage::ClientConnectionClosed(client_id, _) => {
+ match endpoint.clients.contains_key(&client_id) {
+ true => {
+ endpoint.try_disconnect_client(client_id);
connection_lost_events.send(ConnectionLostEvent { id: client_id })
}
- None => (),
+ false => (),
+ }
+ }
+ }
+ }
+
+ let mut lost_clients = HashSet::new();
+ for (client_id, connection) in endpoint.clients.iter_mut() {
+ while let Ok(message) = connection.from_channels_recv.try_recv() {
+ match message {
+ ChannelAsyncMessage::LostConnection => {
+ if !lost_clients.contains(client_id) {
+ lost_clients.insert(*client_id);
+ connection_lost_events.send(ConnectionLostEvent { id: *client_id })
+ }
}
}
}
}
+ for client_id in lost_clients {
+ endpoint.try_disconnect_client(client_id);
+ }
}
}