diff options
author | gilles henaux <gill.henaux@gmail.com> | 2023-01-12 18:05:32 +0100 |
---|---|---|
committer | gilles henaux <gill.henaux@gmail.com> | 2023-01-12 18:05:32 +0100 |
commit | 0ce0557081d7355c3b3d2e8e266ca394e036e03e (patch) | |
tree | 1da7bac6b4c6225ff89dd49fe3834c0d428e280d /src/client.rs | |
parent | d115261d0d2a9feaace1d5ba4e9e2fc9819994cb (diff) |
[channels] Properly signal flush/
termination for all Reliable channels & Internal messages refactor
Diffstat (limited to 'src/client.rs')
-rw-r--r-- | src/client.rs | 172 |
1 files changed, 90 insertions, 82 deletions
diff --git a/src/client.rs b/src/client.rs index 3e23d4a..97b5943 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,7 @@ use std::{ collections::{ hash_map::{Iter, IterMut}, - HashMap, HashSet, + HashMap, }, error::Error, net::SocketAddr, @@ -11,7 +11,7 @@ use std::{ use bevy::prelude::*; use bytes::Bytes; use futures_util::StreamExt; -use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, VarInt}; +use quinn::{ClientConfig, Connection as QuinnConnection, ConnectionError, Endpoint, VarInt}; use quinn_proto::ConnectionStats; use serde::Deserialize; use tokio::{ @@ -30,8 +30,8 @@ use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use crate::shared::{ channel::{ - ordered_reliable_channel_task, unordered_reliable_channel_task, Channel, ChannelId, - ChannelType, MultiChannelId, + ordered_reliable_channel_task, unordered_reliable_channel_task, Channel, + ChannelAsyncMessage, ChannelId, ChannelSyncMessage, ChannelType, MultiChannelId, }, AsyncRuntime, QuinnetError, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE, }; @@ -114,18 +114,9 @@ enum ConnectionState { } #[derive(Debug)] -pub(crate) enum InternalSyncMessage { - CreateChannel { - channel_id: ChannelId, - to_server_receiver: mpsc::Receiver<Bytes>, - channel_close_receiver: mpsc::Receiver<()>, - }, -} - -#[derive(Debug)] -pub(crate) enum InternalAsyncMessage { +pub(crate) enum ClientAsyncMessage { Connected(InternalConnectionRef), - LostConnection, + ConnectionClosed(ConnectionError), CertificateInteractionRequest { status: CertVerificationStatus, info: CertVerificationInfo, @@ -137,15 +128,15 @@ pub(crate) enum InternalAsyncMessage { cert_info: CertVerificationInfo, }, } - #[derive(Debug)] pub(crate) struct ConnectionSpawnConfig { connection_config: ConnectionConfiguration, cert_mode: CertificateVerificationMode, - from_sync_client: mpsc::Receiver<InternalSyncMessage>, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, - close_receiver: broadcast::Receiver<()>, - from_server_sender: mpsc::Sender<Bytes>, + to_sync_client_send: mpsc::Sender<ClientAsyncMessage>, + to_channels_recv: mpsc::Receiver<ChannelSyncMessage>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, + close_recv: broadcast::Receiver<()>, + bytes_from_server_send: mpsc::Sender<Bytes>, } #[derive(Debug)] @@ -154,11 +145,12 @@ pub struct Connection { channels: HashMap<ChannelId, Channel>, default_channel: Option<ChannelId>, last_gen_id: MultiChannelId, - receiver: mpsc::Receiver<Bytes>, + bytes_from_server_recv: mpsc::Receiver<Bytes>, close_sender: broadcast::Sender<()>, - pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>, - pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>, + pub(crate) from_async_client_recv: mpsc::Receiver<ClientAsyncMessage>, + pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>, + pub(crate) from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>, } impl Connection { @@ -249,7 +241,7 @@ impl Connection { pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> { match &self.state { ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), - _ => match self.receiver.try_recv() { + _ => match self.bytes_from_server_recv.try_recv() { Ok(msg_payload) => Ok(Some(msg_payload)), Err(err) => match err { TryRecvError::Empty => Ok(None), @@ -349,20 +341,20 @@ impl Connection { } fn create_channel(&mut self, channel_id: ChannelId) -> Result<ChannelId, QuinnetError> { - let (to_server_sender, to_server_receiver) = + let (bytes_to_server_send, bytes_from_channel_recv) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let (channel_close_sender, channel_close_receiver) = + let (channel_close_send, channel_close_recv) = mpsc::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); match self - .internal_sender - .try_send(InternalSyncMessage::CreateChannel { + .to_channels_send + .try_send(ChannelSyncMessage::CreateChannel { channel_id, - to_server_receiver, - channel_close_receiver, + bytes_to_channel_recv: bytes_from_channel_recv, + channel_close_recv, }) { Ok(_) => { - let channel = Channel::new(to_server_sender, channel_close_sender); + let channel = Channel::new(bytes_to_server_send, channel_close_send); self.channels.insert(channel_id, channel); if self.default_channel.is_none() { self.default_channel = Some(channel_id); @@ -443,13 +435,15 @@ impl Client { config: ConnectionConfiguration, cert_mode: CertificateVerificationMode, ) -> Result<ConnectionId, QuinnetError> { - let (from_server_sender, from_server_receiver) = + let (bytes_from_server_send, bytes_from_server_recv) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let (to_sync_client, from_async_client) = - mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - let (to_async_client, from_sync_client) = - mpsc::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_sync_client_send, from_async_client_recv) = + mpsc::channel::<ClientAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (from_channels_send, from_channels_recv) = + mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_channels_send, to_channels_recv) = + mpsc::channel::<ChannelSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); // Create a close channel for this connection let (close_sender, close_receiver) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); @@ -459,10 +453,11 @@ impl Client { channels: HashMap::new(), last_gen_id: 0, default_channel: None, - receiver: from_server_receiver, + bytes_from_server_recv, close_sender: close_sender.clone(), - internal_receiver: from_async_client, - internal_sender: to_async_client, + from_async_client_recv, + to_channels_send, + from_channels_recv, }; // Create default channels connection.open_channel(ChannelType::OrderedReliable)?; @@ -474,10 +469,11 @@ impl Client { connection_task(ConnectionSpawnConfig { connection_config: config, cert_mode, - from_sync_client, - to_sync_client, - close_receiver, - from_server_sender, + to_channels_recv, + from_channels_send, + to_sync_client_send, + close_recv: close_receiver, + bytes_from_server_send, }) .await }); @@ -531,7 +527,7 @@ impl Client { fn configure_client( cert_mode: CertificateVerificationMode, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, + to_sync_client: mpsc::Sender<ClientAsyncMessage>, ) -> Result<ClientConfig, Box<dyn Error>> { match cert_mode { CertificateVerificationMode::SkipVerification => { @@ -573,8 +569,11 @@ async fn connection_task(spawn_config: ConnectionSpawnConfig) { .parse() .expect("Failed to parse server address"); - let client_cfg = configure_client(spawn_config.cert_mode, spawn_config.to_sync_client.clone()) - .expect("Failed to configure client"); + let client_cfg = configure_client( + spawn_config.cert_mode, + spawn_config.to_sync_client_send.clone(), + ) + .expect("Failed to configure client"); let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap()) .expect("Failed to create client endpoint"); @@ -590,40 +589,40 @@ async fn connection_task(spawn_config: ConnectionSpawnConfig) { info!("Connected to {}", connection.remote_address()); spawn_config - .to_sync_client - .send(InternalAsyncMessage::Connected(connection.clone())) + .to_sync_client_send + .send(ClientAsyncMessage::Connected(connection.clone())) .await .expect("Failed to signal connection to sync client"); let _close_waiter = { let conn = connection.clone(); - let to_sync_client = spawn_config.to_sync_client.clone(); + let to_sync_client = spawn_config.to_sync_client_send.clone(); tokio::spawn(async move { let conn_err = conn.closed().await; info!("Disconnected: {}", conn_err); to_sync_client - .send(InternalAsyncMessage::LostConnection) + .send(ClientAsyncMessage::ConnectionClosed(conn_err)) .await .expect("Failed to signal connection lost to sync client"); }) }; - let close_receiver = spawn_config.close_receiver.resubscribe(); + let close_recv = spawn_config.close_recv.resubscribe(); let connection_handle = connection.clone(); tokio::spawn(async move { connection_receiving_task( connection_handle, - spawn_config.from_server_sender, - close_receiver, + spawn_config.bytes_from_server_send, + close_recv, ) .await }); handle_connection_channels( connection, - spawn_config.close_receiver, - spawn_config.from_sync_client, - spawn_config.to_sync_client, + spawn_config.close_recv, + spawn_config.to_channels_recv, + spawn_config.from_channels_send, ) .await; } @@ -632,18 +631,18 @@ async fn connection_task(spawn_config: ConnectionSpawnConfig) { async fn connection_receiving_task( connection: quinn::Connection, - from_server_sender: mpsc::Sender<Bytes>, - mut close_receiver: broadcast::Receiver<()>, + bytes_from_server_send: mpsc::Sender<Bytes>, + mut close_recv: broadcast::Receiver<()>, ) { let mut uni_receivers: JoinSet<()> = JoinSet::new(); tokio::select! { - _ = close_receiver.recv() => { + _ = close_recv.recv() => { trace!("Listener for new Unidirectional Receiving Streams received a close signal") } _ = async { while let Ok(recv) = connection.accept_uni().await { let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - let from_server_sender = from_server_sender.clone(); + let from_server_sender = bytes_from_server_send.clone(); uni_receivers.spawn(async move { while let Some(Ok(msg_bytes)) = frame_recv.next().await { // TODO Clean: error handling @@ -661,25 +660,25 @@ async fn connection_receiving_task( async fn handle_connection_channels( connection: quinn::Connection, - mut close_receiver: broadcast::Receiver<()>, - mut from_sync_client: mpsc::Receiver<InternalSyncMessage>, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, + mut close_recv: broadcast::Receiver<()>, + mut to_channels_recv: mpsc::Receiver<ChannelSyncMessage>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, ) { // Use an mpsc channel where, instead of sending messages, we wait for the channel to be closed, which happens when every sender has been dropped. We can't use a JoinSet as simply here since we would also need to drain closed channels from it. let (channel_tasks_keepalive, mut channel_tasks_waiter) = mpsc::channel(1); - let close_receiver_clone = close_receiver.resubscribe(); + let close_receiver_clone = close_recv.resubscribe(); tokio::select! { - _ = close_receiver.recv() => { + _ = close_recv.recv() => { trace!("Connection Channels listener received a close signal") } _ = async { - while let Some(sync_message) = from_sync_client.recv().await { - let InternalSyncMessage::CreateChannel{ channel_id, to_server_receiver, channel_close_receiver } = sync_message; + while let Some(sync_message) = to_channels_recv.recv().await { + let ChannelSyncMessage::CreateChannel{ channel_id, bytes_to_channel_recv, channel_close_recv } = sync_message; let close_receiver = close_receiver_clone.resubscribe(); let connection_handle = connection.clone(); - let to_sync_client = to_sync_client.clone(); + let from_channels_send = from_channels_send.clone(); let channels_keepalive_clone = channel_tasks_keepalive.clone(); match channel_id { @@ -688,11 +687,10 @@ async fn handle_connection_channels( ordered_reliable_channel_task( connection_handle, channels_keepalive_clone, - to_sync_client, - || InternalAsyncMessage::LostConnection, + from_channels_send, close_receiver, - channel_close_receiver, - to_server_receiver + channel_close_recv, + bytes_to_channel_recv ) .await }); @@ -702,11 +700,10 @@ async fn handle_connection_channels( unordered_reliable_channel_task( connection_handle, channels_keepalive_clone, - to_sync_client, - || InternalAsyncMessage::LostConnection, + from_channels_send, close_receiver, - channel_close_receiver, - to_server_receiver + channel_close_recv, + bytes_to_channel_recv ) .await }); @@ -738,20 +735,20 @@ fn update_sync_client( mut client: ResMut<Client>, ) { for (connection_id, mut connection) in &mut client.connections { - while let Ok(message) = connection.internal_receiver.try_recv() { + while let Ok(message) = connection.from_async_client_recv.try_recv() { match message { - InternalAsyncMessage::Connected(internal_connection) => { + ClientAsyncMessage::Connected(internal_connection) => { connection.state = ConnectionState::Connected(internal_connection); connection_events.send(ConnectionEvent { id: *connection_id }); } - InternalAsyncMessage::LostConnection => match connection.state { + ClientAsyncMessage::ConnectionClosed(_) => match connection.state { ConnectionState::Disconnected => (), _ => { connection.try_disconnect(); connection_lost_events.send(ConnectionLostEvent { id: *connection_id }); } }, - InternalAsyncMessage::CertificateInteractionRequest { + ClientAsyncMessage::CertificateInteractionRequest { status, info, action_sender, @@ -763,13 +760,13 @@ fn update_sync_client( action_sender: Mutex::new(Some(action_sender)), }); } - InternalAsyncMessage::CertificateTrustUpdate(info) => { + ClientAsyncMessage::CertificateTrustUpdate(info) => { cert_trust_update_events.send(CertTrustUpdateEvent { connection_id: *connection_id, cert_info: info, }); } - InternalAsyncMessage::CertificateConnectionAbort { status, cert_info } => { + ClientAsyncMessage::CertificateConnectionAbort { status, cert_info } => { cert_connection_abort_events.send(CertConnectionAbortEvent { connection_id: *connection_id, status, @@ -778,6 +775,17 @@ fn update_sync_client( } } } + while let Ok(message) = connection.from_channels_recv.try_recv() { + match message { + ChannelAsyncMessage::LostConnection => match connection.state { + ConnectionState::Disconnected => (), + _ => { + connection.try_disconnect(); + connection_lost_events.send(ConnectionLostEvent { id: *connection_id }); + } + }, + } + } } } |