diff options
-rw-r--r-- | src/client.rs | 172 | ||||
-rw-r--r-- | src/client/certificate.rs | 18 | ||||
-rw-r--r-- | src/shared/channel.rs | 101 |
3 files changed, 163 insertions, 128 deletions
diff --git a/src/client.rs b/src/client.rs index 3e23d4a..97b5943 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,7 @@ use std::{ collections::{ hash_map::{Iter, IterMut}, - HashMap, HashSet, + HashMap, }, error::Error, net::SocketAddr, @@ -11,7 +11,7 @@ use std::{ use bevy::prelude::*; use bytes::Bytes; use futures_util::StreamExt; -use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, VarInt}; +use quinn::{ClientConfig, Connection as QuinnConnection, ConnectionError, Endpoint, VarInt}; use quinn_proto::ConnectionStats; use serde::Deserialize; use tokio::{ @@ -30,8 +30,8 @@ use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use crate::shared::{ channel::{ - ordered_reliable_channel_task, unordered_reliable_channel_task, Channel, ChannelId, - ChannelType, MultiChannelId, + ordered_reliable_channel_task, unordered_reliable_channel_task, Channel, + ChannelAsyncMessage, ChannelId, ChannelSyncMessage, ChannelType, MultiChannelId, }, AsyncRuntime, QuinnetError, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE, }; @@ -114,18 +114,9 @@ enum ConnectionState { } #[derive(Debug)] -pub(crate) enum InternalSyncMessage { - CreateChannel { - channel_id: ChannelId, - to_server_receiver: mpsc::Receiver<Bytes>, - channel_close_receiver: mpsc::Receiver<()>, - }, -} - -#[derive(Debug)] -pub(crate) enum InternalAsyncMessage { +pub(crate) enum ClientAsyncMessage { Connected(InternalConnectionRef), - LostConnection, + ConnectionClosed(ConnectionError), CertificateInteractionRequest { status: CertVerificationStatus, info: CertVerificationInfo, @@ -137,15 +128,15 @@ pub(crate) enum InternalAsyncMessage { cert_info: CertVerificationInfo, }, } - #[derive(Debug)] pub(crate) struct ConnectionSpawnConfig { connection_config: ConnectionConfiguration, cert_mode: CertificateVerificationMode, - from_sync_client: mpsc::Receiver<InternalSyncMessage>, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, - close_receiver: broadcast::Receiver<()>, - from_server_sender: mpsc::Sender<Bytes>, + to_sync_client_send: mpsc::Sender<ClientAsyncMessage>, + to_channels_recv: mpsc::Receiver<ChannelSyncMessage>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, + close_recv: broadcast::Receiver<()>, + bytes_from_server_send: mpsc::Sender<Bytes>, } #[derive(Debug)] @@ -154,11 +145,12 @@ pub struct Connection { channels: HashMap<ChannelId, Channel>, default_channel: Option<ChannelId>, last_gen_id: MultiChannelId, - receiver: mpsc::Receiver<Bytes>, + bytes_from_server_recv: mpsc::Receiver<Bytes>, close_sender: broadcast::Sender<()>, - pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>, - pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>, + pub(crate) from_async_client_recv: mpsc::Receiver<ClientAsyncMessage>, + pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>, + pub(crate) from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>, } impl Connection { @@ -249,7 +241,7 @@ impl Connection { pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> { match &self.state { ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), - _ => match self.receiver.try_recv() { + _ => match self.bytes_from_server_recv.try_recv() { Ok(msg_payload) => Ok(Some(msg_payload)), Err(err) => match err { TryRecvError::Empty => Ok(None), @@ -349,20 +341,20 @@ impl Connection { } fn create_channel(&mut self, channel_id: ChannelId) -> Result<ChannelId, QuinnetError> { - let (to_server_sender, to_server_receiver) = + let (bytes_to_server_send, bytes_from_channel_recv) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let (channel_close_sender, channel_close_receiver) = + let (channel_close_send, channel_close_recv) = mpsc::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); match self - .internal_sender - .try_send(InternalSyncMessage::CreateChannel { + .to_channels_send + .try_send(ChannelSyncMessage::CreateChannel { channel_id, - to_server_receiver, - channel_close_receiver, + bytes_to_channel_recv: bytes_from_channel_recv, + channel_close_recv, }) { Ok(_) => { - let channel = Channel::new(to_server_sender, channel_close_sender); + let channel = Channel::new(bytes_to_server_send, channel_close_send); self.channels.insert(channel_id, channel); if self.default_channel.is_none() { self.default_channel = Some(channel_id); @@ -443,13 +435,15 @@ impl Client { config: ConnectionConfiguration, cert_mode: CertificateVerificationMode, ) -> Result<ConnectionId, QuinnetError> { - let (from_server_sender, from_server_receiver) = + let (bytes_from_server_send, bytes_from_server_recv) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let (to_sync_client, from_async_client) = - mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - let (to_async_client, from_sync_client) = - mpsc::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_sync_client_send, from_async_client_recv) = + mpsc::channel::<ClientAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (from_channels_send, from_channels_recv) = + mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_channels_send, to_channels_recv) = + mpsc::channel::<ChannelSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); // Create a close channel for this connection let (close_sender, close_receiver) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); @@ -459,10 +453,11 @@ impl Client { channels: HashMap::new(), last_gen_id: 0, default_channel: None, - receiver: from_server_receiver, + bytes_from_server_recv, close_sender: close_sender.clone(), - internal_receiver: from_async_client, - internal_sender: to_async_client, + from_async_client_recv, + to_channels_send, + from_channels_recv, }; // Create default channels connection.open_channel(ChannelType::OrderedReliable)?; @@ -474,10 +469,11 @@ impl Client { connection_task(ConnectionSpawnConfig { connection_config: config, cert_mode, - from_sync_client, - to_sync_client, - close_receiver, - from_server_sender, + to_channels_recv, + from_channels_send, + to_sync_client_send, + close_recv: close_receiver, + bytes_from_server_send, }) .await }); @@ -531,7 +527,7 @@ impl Client { fn configure_client( cert_mode: CertificateVerificationMode, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, + to_sync_client: mpsc::Sender<ClientAsyncMessage>, ) -> Result<ClientConfig, Box<dyn Error>> { match cert_mode { CertificateVerificationMode::SkipVerification => { @@ -573,8 +569,11 @@ async fn connection_task(spawn_config: ConnectionSpawnConfig) { .parse() .expect("Failed to parse server address"); - let client_cfg = configure_client(spawn_config.cert_mode, spawn_config.to_sync_client.clone()) - .expect("Failed to configure client"); + let client_cfg = configure_client( + spawn_config.cert_mode, + spawn_config.to_sync_client_send.clone(), + ) + .expect("Failed to configure client"); let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap()) .expect("Failed to create client endpoint"); @@ -590,40 +589,40 @@ async fn connection_task(spawn_config: ConnectionSpawnConfig) { info!("Connected to {}", connection.remote_address()); spawn_config - .to_sync_client - .send(InternalAsyncMessage::Connected(connection.clone())) + .to_sync_client_send + .send(ClientAsyncMessage::Connected(connection.clone())) .await .expect("Failed to signal connection to sync client"); let _close_waiter = { let conn = connection.clone(); - let to_sync_client = spawn_config.to_sync_client.clone(); + let to_sync_client = spawn_config.to_sync_client_send.clone(); tokio::spawn(async move { let conn_err = conn.closed().await; info!("Disconnected: {}", conn_err); to_sync_client - .send(InternalAsyncMessage::LostConnection) + .send(ClientAsyncMessage::ConnectionClosed(conn_err)) .await .expect("Failed to signal connection lost to sync client"); }) }; - let close_receiver = spawn_config.close_receiver.resubscribe(); + let close_recv = spawn_config.close_recv.resubscribe(); let connection_handle = connection.clone(); tokio::spawn(async move { connection_receiving_task( connection_handle, - spawn_config.from_server_sender, - close_receiver, + spawn_config.bytes_from_server_send, + close_recv, ) .await }); handle_connection_channels( connection, - spawn_config.close_receiver, - spawn_config.from_sync_client, - spawn_config.to_sync_client, + spawn_config.close_recv, + spawn_config.to_channels_recv, + spawn_config.from_channels_send, ) .await; } @@ -632,18 +631,18 @@ async fn connection_task(spawn_config: ConnectionSpawnConfig) { async fn connection_receiving_task( connection: quinn::Connection, - from_server_sender: mpsc::Sender<Bytes>, - mut close_receiver: broadcast::Receiver<()>, + bytes_from_server_send: mpsc::Sender<Bytes>, + mut close_recv: broadcast::Receiver<()>, ) { let mut uni_receivers: JoinSet<()> = JoinSet::new(); tokio::select! { - _ = close_receiver.recv() => { + _ = close_recv.recv() => { trace!("Listener for new Unidirectional Receiving Streams received a close signal") } _ = async { while let Ok(recv) = connection.accept_uni().await { let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - let from_server_sender = from_server_sender.clone(); + let from_server_sender = bytes_from_server_send.clone(); uni_receivers.spawn(async move { while let Some(Ok(msg_bytes)) = frame_recv.next().await { // TODO Clean: error handling @@ -661,25 +660,25 @@ async fn connection_receiving_task( async fn handle_connection_channels( connection: quinn::Connection, - mut close_receiver: broadcast::Receiver<()>, - mut from_sync_client: mpsc::Receiver<InternalSyncMessage>, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, + mut close_recv: broadcast::Receiver<()>, + mut to_channels_recv: mpsc::Receiver<ChannelSyncMessage>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, ) { // Use an mpsc channel where, instead of sending messages, we wait for the channel to be closed, which happens when every sender has been dropped. We can't use a JoinSet as simply here since we would also need to drain closed channels from it. let (channel_tasks_keepalive, mut channel_tasks_waiter) = mpsc::channel(1); - let close_receiver_clone = close_receiver.resubscribe(); + let close_receiver_clone = close_recv.resubscribe(); tokio::select! { - _ = close_receiver.recv() => { + _ = close_recv.recv() => { trace!("Connection Channels listener received a close signal") } _ = async { - while let Some(sync_message) = from_sync_client.recv().await { - let InternalSyncMessage::CreateChannel{ channel_id, to_server_receiver, channel_close_receiver } = sync_message; + while let Some(sync_message) = to_channels_recv.recv().await { + let ChannelSyncMessage::CreateChannel{ channel_id, bytes_to_channel_recv, channel_close_recv } = sync_message; let close_receiver = close_receiver_clone.resubscribe(); let connection_handle = connection.clone(); - let to_sync_client = to_sync_client.clone(); + let from_channels_send = from_channels_send.clone(); let channels_keepalive_clone = channel_tasks_keepalive.clone(); match channel_id { @@ -688,11 +687,10 @@ async fn handle_connection_channels( ordered_reliable_channel_task( connection_handle, channels_keepalive_clone, - to_sync_client, - || InternalAsyncMessage::LostConnection, + from_channels_send, close_receiver, - channel_close_receiver, - to_server_receiver + channel_close_recv, + bytes_to_channel_recv ) .await }); @@ -702,11 +700,10 @@ async fn handle_connection_channels( unordered_reliable_channel_task( connection_handle, channels_keepalive_clone, - to_sync_client, - || InternalAsyncMessage::LostConnection, + from_channels_send, close_receiver, - channel_close_receiver, - to_server_receiver + channel_close_recv, + bytes_to_channel_recv ) .await }); @@ -738,20 +735,20 @@ fn update_sync_client( mut client: ResMut<Client>, ) { for (connection_id, mut connection) in &mut client.connections { - while let Ok(message) = connection.internal_receiver.try_recv() { + while let Ok(message) = connection.from_async_client_recv.try_recv() { match message { - InternalAsyncMessage::Connected(internal_connection) => { + ClientAsyncMessage::Connected(internal_connection) => { connection.state = ConnectionState::Connected(internal_connection); connection_events.send(ConnectionEvent { id: *connection_id }); } - InternalAsyncMessage::LostConnection => match connection.state { + ClientAsyncMessage::ConnectionClosed(_) => match connection.state { ConnectionState::Disconnected => (), _ => { connection.try_disconnect(); connection_lost_events.send(ConnectionLostEvent { id: *connection_id }); } }, - InternalAsyncMessage::CertificateInteractionRequest { + ClientAsyncMessage::CertificateInteractionRequest { status, info, action_sender, @@ -763,13 +760,13 @@ fn update_sync_client( action_sender: Mutex::new(Some(action_sender)), }); } - InternalAsyncMessage::CertificateTrustUpdate(info) => { + ClientAsyncMessage::CertificateTrustUpdate(info) => { cert_trust_update_events.send(CertTrustUpdateEvent { connection_id: *connection_id, cert_info: info, }); } - InternalAsyncMessage::CertificateConnectionAbort { status, cert_info } => { + ClientAsyncMessage::CertificateConnectionAbort { status, cert_info } => { cert_connection_abort_events.send(CertConnectionAbortEvent { connection_id: *connection_id, status, @@ -778,6 +775,17 @@ fn update_sync_client( } } } + while let Ok(message) = connection.from_channels_recv.try_recv() { + match message { + ChannelAsyncMessage::LostConnection => match connection.state { + ConnectionState::Disconnected => (), + _ => { + connection.try_disconnect(); + connection_lost_events.send(ConnectionLostEvent { id: *connection_id }); + } + }, + } + } } } diff --git a/src/client/certificate.rs b/src/client/certificate.rs index 76bd52d..a6d9b8c 100644 --- a/src/client/certificate.rs +++ b/src/client/certificate.rs @@ -15,7 +15,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::shared::{CertificateFingerprint, QuinnetError}; -use super::{ConnectionId, InternalAsyncMessage, DEFAULT_KNOWN_HOSTS_FILE}; +use super::{ClientAsyncMessage, ConnectionId, DEFAULT_KNOWN_HOSTS_FILE}; pub const DEFAULT_CERT_VERIFIER_BEHAVIOUR: CertVerifierBehaviour = CertVerifierBehaviour::ImmediateAction(CertVerifierAction::AbortConnection); @@ -211,7 +211,7 @@ impl rustls::client::ServerCertVerifier for SkipServerVerification { pub(crate) struct TofuServerVerification { store: CertStore, verifier_behaviour: HashMap<CertVerificationStatus, CertVerifierBehaviour>, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, + to_sync_client: mpsc::Sender<ClientAsyncMessage>, /// If present, the file where new fingerprints should be stored hosts_file: Option<String>, @@ -221,7 +221,7 @@ impl TofuServerVerification { pub(crate) fn new( store: CertStore, verifier_behaviour: HashMap<CertVerificationStatus, CertVerifierBehaviour>, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, + to_sync_client: mpsc::Sender<ClientAsyncMessage>, hosts_file: Option<String>, ) -> Arc<Self> { Arc::new(Self { @@ -248,7 +248,7 @@ impl TofuServerVerification { CertVerifierBehaviour::RequestClientAction => { let (action_sender, cert_action_recv) = oneshot::channel::<CertVerifierAction>(); self.to_sync_client - .try_send(InternalAsyncMessage::CertificateInteractionRequest { + .try_send(ClientAsyncMessage::CertificateInteractionRequest { status: status.clone(), info: cert_info.clone(), action_sender, @@ -273,12 +273,12 @@ impl TofuServerVerification { ) -> Result<rustls::client::ServerCertVerified, rustls::Error> { match action { CertVerifierAction::AbortConnection => { - match self.to_sync_client.try_send( - InternalAsyncMessage::CertificateConnectionAbort { + match self + .to_sync_client + .try_send(ClientAsyncMessage::CertificateConnectionAbort { status: status, cert_info, - }, - ) { + }) { Ok(_) => Err(rustls::Error::InvalidCertificateData(format!( "CertVerifierAction requested to abort the connection" ))), @@ -304,7 +304,7 @@ impl TofuServerVerification { // In all cases raise an event containing the new certificate entry match self .to_sync_client - .try_send(InternalAsyncMessage::CertificateTrustUpdate(cert_info)) + .try_send(ClientAsyncMessage::CertificateTrustUpdate(cert_info)) { Ok(_) => Ok(rustls::client::ServerCertVerified::assertion()), Err(_) => Err(rustls::Error::General(format!( diff --git a/src/shared/channel.rs b/src/shared/channel.rs index 72dca41..759e1b8 100644 --- a/src/shared/channel.rs +++ b/src/shared/channel.rs @@ -33,6 +33,20 @@ impl std::fmt::Display for ChannelId { } #[derive(Debug)] +pub(crate) enum ChannelAsyncMessage { + LostConnection, +} + +#[derive(Debug)] +pub(crate) enum ChannelSyncMessage { + CreateChannel { + channel_id: ChannelId, + bytes_to_channel_recv: mpsc::Receiver<Bytes>, + channel_close_recv: mpsc::Receiver<()>, + }, +} + +#[derive(Debug)] pub struct Channel { sender: mpsc::Sender<Bytes>, close_sender: mpsc::Sender<()>, @@ -67,30 +81,29 @@ impl Channel { } } -pub(crate) async fn ordered_reliable_channel_task<T: Debug>( +pub(crate) async fn ordered_reliable_channel_task( connection: quinn::Connection, _: mpsc::Sender<()>, - to_sync_client: mpsc::Sender<T>, - on_lost_connection: fn() -> T, - mut close_receiver: broadcast::Receiver<()>, - mut channel_close_receiver: mpsc::Receiver<()>, - mut to_server_receiver: mpsc::Receiver<Bytes>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, + mut close_recv: broadcast::Receiver<()>, + mut channel_close_recv: mpsc::Receiver<()>, + mut bytes_to_channel_recv: mpsc::Receiver<Bytes>, ) { let mut frame_sender = new_uni_frame_sender(&connection).await; tokio::select! { - _ = close_receiver.recv() => { + _ = close_recv.recv() => { trace!("Ordered Reliable Channel task received a close signal") } - _ = channel_close_receiver.recv() => { + _ = channel_close_recv.recv() => { trace!("Ordered Reliable Channel task received a channel close signal") } _ = async { - while let Some(msg_bytes) = to_server_receiver.recv().await { + while let Some(msg_bytes) = bytes_to_channel_recv.recv().await { if let Err(err) = frame_sender.send(msg_bytes).await { error!("Error while sending, {}", err); - to_sync_client.send( - on_lost_connection()) + from_channels_send.send( + ChannelAsyncMessage::LostConnection) .await .expect("Failed to signal connection lost to sync client"); } @@ -99,7 +112,7 @@ pub(crate) async fn ordered_reliable_channel_task<T: Debug>( trace!("Ordered Reliable Channel task ended") } }; - while let Ok(msg_bytes) = to_server_receiver.try_recv() { + while let Ok(msg_bytes) = bytes_to_channel_recv.try_recv() { if let Err(err) = frame_sender.send(msg_bytes).await { error!("Error while sending, {}", err); } @@ -112,44 +125,58 @@ pub(crate) async fn ordered_reliable_channel_task<T: Debug>( } } -pub(crate) async fn unordered_reliable_channel_task<T: Debug>( +pub(crate) async fn unordered_reliable_channel_task( connection: quinn::Connection, - _: mpsc::Sender<()>, - to_sync_client: mpsc::Sender<T>, - on_lost_connection: fn() -> T, - mut close_receiver: broadcast::Receiver<()>, - mut channel_close_receiver: mpsc::Receiver<()>, - mut to_server_receiver: mpsc::Receiver<Bytes>, + channel_tasks_keepalive: mpsc::Sender<()>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, + mut close_recv: broadcast::Receiver<()>, + mut channel_close_recv: mpsc::Receiver<()>, + mut bytes_to_channel_recv: mpsc::Receiver<Bytes>, ) { tokio::select! { - _ = close_receiver.recv() => { + _ = close_recv.recv() => { trace!("Ordered Reliable Channel task received a close signal") } - _ = channel_close_receiver.recv() => { + _ = channel_close_recv.recv() => { trace!("Unordered Reliable Channel task received a channel close signal") } _ = async { - while let Some(msg_bytes) = to_server_receiver.recv().await { - let mut frame_sender = new_uni_frame_sender(&connection).await; - if let Err(err) = frame_sender.send(msg_bytes).await { - error!("Error while sending, {}", err); - to_sync_client.send( - on_lost_connection()) - .await - .expect("Failed to signal connection lost to sync client"); - } - todo!("finish the stream") + while let Some(msg_bytes) = bytes_to_channel_recv.recv().await { + let conn = connection.clone(); + let to_sync_client_clone = from_channels_send.clone(); + let channels_keepalive_clone = channel_tasks_keepalive.clone(); + tokio::spawn(async move { + let mut frame_sender = new_uni_frame_sender(&conn).await; + if let Err(err) = frame_sender.send(msg_bytes).await { + error!("Error while sending, {}", err); + to_sync_client_clone.send( + ChannelAsyncMessage::LostConnection) + .await + .expect("Failed to signal connection lost to sync client"); + } + if let Err(err) = frame_sender.into_inner().finish().await { + error!("Failed to shutdown stream gracefully: {}", err); + } + drop(channels_keepalive_clone) + }); } } => { trace!("Unordered Reliable Channel task ended") } }; - while let Ok(msg_bytes) = to_server_receiver.try_recv() { - let mut frame_sender = new_uni_frame_sender(&connection).await; - if let Err(err) = frame_sender.send(msg_bytes).await { - error!("Error while sending, {}", err); - } - todo!("finish the stream") + while let Ok(msg_bytes) = bytes_to_channel_recv.try_recv() { + let conn = connection.clone(); + let channels_keepalive_clone = channel_tasks_keepalive.clone(); + tokio::spawn(async move { + let mut frame_sender = new_uni_frame_sender(&conn).await; + if let Err(err) = frame_sender.send(msg_bytes).await { + error!("Error while sending, {}", err); + } + if let Err(err) = frame_sender.into_inner().finish().await { + error!("Failed to shutdown stream gracefully: {}", err); + } + drop(channels_keepalive_clone) + }); } } |