diff options
author | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2023-01-17 20:42:50 +0100 |
---|---|---|
committer | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2023-01-17 20:42:50 +0100 |
commit | 1e7ef10aeeac969eb7253d86370cbc1ad7548fc9 (patch) | |
tree | 810f3a7698b0ad6231dfafb218388b4662879813 /src | |
parent | 4e9fc77604a9ef1745fb3ee08eb5bc66cb2711b3 (diff) |
[client] Move client::connection into a separate file
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 560 | ||||
-rw-r--r-- | src/client/connection.rs | 544 | ||||
-rw-r--r-- | src/lib.rs | 5 | ||||
-rw-r--r-- | src/shared.rs | 2 |
4 files changed, 571 insertions, 540 deletions
diff --git a/src/client.rs b/src/client.rs index 8a52441..51cc0ea 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,115 +3,43 @@ use std::{ hash_map::{Iter, IterMut}, HashMap, }, - error::Error, - net::SocketAddr, - sync::{Arc, Mutex}, + sync::Mutex, }; use bevy::prelude::*; use bytes::Bytes; -use futures_util::StreamExt; -use quinn::{ClientConfig, Connection as QuinnConnection, ConnectionError, Endpoint, RecvStream}; -use quinn_proto::ConnectionStats; -use serde::Deserialize; +use quinn::ConnectionError; use tokio::{ runtime::{self}, sync::{ broadcast, - mpsc::{ - self, - error::{TryRecvError, TrySendError}, - }, + mpsc::{self}, oneshot, }, }; -use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use crate::shared::{ - channel::{ - channels_task, get_channel_id_from_type, Channel, ChannelAsyncMessage, ChannelId, - ChannelSyncMessage, ChannelType, MultiChannelId, - }, + channel::{ChannelAsyncMessage, ChannelId, ChannelSyncMessage, ChannelType}, AsyncRuntime, QuinnetError, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE, }; -use self::certificate::{ - load_known_hosts_store_from_config, CertConnectionAbortEvent, CertInteractionEvent, - CertTrustUpdateEvent, CertVerificationInfo, CertVerificationStatus, CertVerifierAction, - CertificateVerificationMode, SkipServerVerification, TofuServerVerification, +use self::{ + certificate::{ + CertConnectionAbortEvent, CertInteractionEvent, CertTrustUpdateEvent, CertVerificationInfo, + CertVerificationStatus, CertVerifierAction, CertificateVerificationMode, + }, + connection::{ + connection_task, Connection, ConnectionConfiguration, ConnectionEvent, ConnectionId, + ConnectionLostEvent, ConnectionState, InternalConnectionRef, + }, }; pub mod certificate; +pub mod connection; pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100; pub const DEFAULT_KNOWN_HOSTS_FILE: &str = "quinnet/known_hosts"; -pub type ConnectionId = u64; - -/// Connection event raised when the client just connected to the server. Raised in the CoreStage::PreUpdate stage. -pub struct ConnectionEvent { - pub id: ConnectionId, -} -/// ConnectionLost event raised when the client is considered disconnected from the server. Raised in the CoreStage::PreUpdate stage. -pub struct ConnectionLostEvent { - pub id: ConnectionId, -} - -/// Configuration of the client, used when connecting to a server -#[derive(Debug, Deserialize, Clone)] -pub struct ConnectionConfiguration { - server_host: String, - server_port: u16, - local_bind_host: String, - local_bind_port: u16, -} - -impl ConnectionConfiguration { - /// Creates a new ClientConfigurationData - /// - /// # Arguments - /// - /// * `server_host` - Address of the server - /// * `server_port` - Port that the server is listening on - /// * `local_bind_host` - Local address to bind to, which should usually be a wildcard address like `0.0.0.0` or `[::]`, which allow communication with any reachable IPv4 or IPv6 address. See [`quinn::endpoint::Endpoint`] for more precision - /// * `local_bind_port` - Local port to bind to. Use 0 to get an OS-assigned port.. See [`quinn::endpoint::Endpoint`] for more precision - /// - /// # Examples - /// - /// ``` - /// use bevy_quinnet::client::ConnectionConfiguration; - /// let config = ConnectionConfiguration::new( - /// "127.0.0.1".to_string(), - /// 6000, - /// "0.0.0.0".to_string(), - /// 0, - /// ); - /// ``` - pub fn new( - server_host: String, - server_port: u16, - local_bind_host: String, - local_bind_port: u16, - ) -> Self { - Self { - server_host, - server_port, - local_bind_host, - local_bind_port, - } - } -} - -type InternalConnectionRef = QuinnConnection; - -/// Current state of a client connection -#[derive(Debug)] -enum ConnectionState { - Connecting, - Connected(InternalConnectionRef), - Disconnected, -} - #[derive(Debug)] pub(crate) enum ClientAsyncMessage { Connected(InternalConnectionRef), @@ -127,255 +55,6 @@ pub(crate) enum ClientAsyncMessage { cert_info: CertVerificationInfo, }, } -#[derive(Debug)] -pub(crate) struct ConnectionSpawnConfig { - connection_config: ConnectionConfiguration, - cert_mode: CertificateVerificationMode, - to_sync_client_send: mpsc::Sender<ClientAsyncMessage>, - to_channels_recv: mpsc::Receiver<ChannelSyncMessage>, - from_channels_send: mpsc::Sender<ChannelAsyncMessage>, - close_recv: broadcast::Receiver<()>, - bytes_from_server_send: mpsc::Sender<Bytes>, -} - -#[derive(Debug)] -pub struct Connection { - state: ConnectionState, - channels: HashMap<ChannelId, Channel>, - default_channel: Option<ChannelId>, - last_gen_id: MultiChannelId, - bytes_from_server_recv: mpsc::Receiver<Bytes>, - close_sender: broadcast::Sender<()>, - - pub(crate) from_async_client_recv: mpsc::Receiver<ClientAsyncMessage>, - pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>, - pub(crate) from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>, -} - -impl Connection { - pub fn receive_message<T: serde::de::DeserializeOwned>( - &mut self, - ) -> Result<Option<T>, QuinnetError> { - match self.receive_payload()? { - Some(payload) => match bincode::deserialize(&payload) { - Ok(msg) => Ok(Some(msg)), - Err(_) => Err(QuinnetError::Deserialization), - }, - None => Ok(None), - } - } - - /// Same as [Connection::receive_message] but will log the error instead of returning it - pub fn try_receive_message<T: serde::de::DeserializeOwned>(&mut self) -> Option<T> { - match self.receive_message() { - Ok(message) => message, - Err(err) => { - error!("try_receive_message: {}", err); - None - } - } - } - - pub fn send_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> { - match self.default_channel { - Some(channel) => self.send_message_on(channel, message), - None => Err(QuinnetError::NoDefaultChannel), - } - } - - pub fn send_message_on<T: serde::Serialize>( - &self, - channel_id: ChannelId, - message: T, - ) -> Result<(), QuinnetError> { - match &self.state { - ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), - _ => match self.channels.get(&channel_id) { - Some(channel) => match bincode::serialize(&message) { - Ok(payload) => channel.send_payload(payload), - Err(_) => Err(QuinnetError::Serialization), - }, - None => Err(QuinnetError::UnknownChannel(channel_id)), - }, - } - } - - /// Same as [Connection::send_message] but will log the error instead of returning it - pub fn try_send_message<T: serde::Serialize>(&self, message: T) { - match self.send_message(message) { - Ok(_) => {} - Err(err) => error!("try_send_message: {}", err), - } - } - - pub fn send_payload<T: Into<Bytes>>(&self, payload: T) -> Result<(), QuinnetError> { - match self.default_channel { - Some(channel) => self.send_payload_on(channel, payload), - None => Err(QuinnetError::NoDefaultChannel), - } - } - - pub fn send_payload_on<T: Into<Bytes>>( - &self, - channel_id: ChannelId, - payload: T, - ) -> Result<(), QuinnetError> { - match &self.state { - ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), - _ => match self.channels.get(&channel_id) { - Some(channel) => channel.send_payload(payload), - None => Err(QuinnetError::UnknownChannel(channel_id)), - }, - } - } - - /// Same as [Connection::send_payload] but will log the error instead of returning it - pub fn try_send_payload<T: Into<Bytes>>(&self, payload: T) { - match self.send_payload(payload) { - Ok(_) => {} - Err(err) => error!("try_send_payload: {}", err), - } - } - - pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> { - match &self.state { - ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), - _ => match self.bytes_from_server_recv.try_recv() { - Ok(msg_payload) => Ok(Some(msg_payload)), - Err(err) => match err { - TryRecvError::Empty => Ok(None), - TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed), - }, - }, - } - } - - /// Same as [Connection::receive_payload] but will log the error instead of returning it - pub fn try_receive_payload(&mut self) -> Option<Bytes> { - match self.receive_payload() { - Ok(payload) => payload, - Err(err) => { - error!("try_receive_payload: {}", err); - None - } - } - } - - /// Immediately prevents new messages from being sent on the connection and signal the connection to closes all its background tasks. Before trully closing, the connection will wait for all buffered messages in all its opened channels to be properly sent according to their respective channel type. - fn disconnect(&mut self) -> Result<(), QuinnetError> { - match &self.state { - ConnectionState::Disconnected => Ok(()), - _ => { - self.state = ConnectionState::Disconnected; - match self.close_sender.send(()) { - Ok(_) => Ok(()), - Err(_) => { - // The only possible error for a send is that there is no active receivers, meaning that the tasks are already terminated. - Err(QuinnetError::ConnectionAlreadyClosed) - } - } - } - } - } - - fn try_disconnect(&mut self) { - match &self.disconnect() { - Ok(_) => (), - Err(err) => error!("Failed to properly close clonnection: {}", err), - } - } - - pub fn is_connected(&self) -> bool { - match self.state { - ConnectionState::Connected(_) => true, - _ => false, - } - } - - /// Returns statistics about the current connection if connected. - pub fn stats(&self) -> Option<ConnectionStats> { - match &self.state { - ConnectionState::Connected(connection) => Some(connection.stats()), - _ => None, - } - } - - /// Opens a channel of the requested [ChannelType] and returns its [ChannelId]. - /// - /// By default, when starting a [Connection]], Quinnet creates 1 channel instance of each [ChannelType], each with their own [ChannelId]. Among those, there is a `default` channel which will be used when you don't specify the channel. At startup, this default channel is a [ChannelType::OrderedReliable] channel. - /// - /// If no channels were previously opened, the opened channel will be the new default channel. - /// - /// Can fail if the Connection is closed. - pub fn open_channel(&mut self, channel_type: ChannelType) -> Result<ChannelId, QuinnetError> { - let channel_id = get_channel_id_from_type(channel_type, || { - self.last_gen_id += 1; - self.last_gen_id - }); - match self.channels.contains_key(&channel_id) { - true => Ok(channel_id), - false => self.create_channel(channel_id), - } - } - - /// Closes the channel with the corresponding [ChannelId]. - /// - /// No new messages will be able to be sent on this channel, however, the channel will properly try to send all the messages that were previously pushed to it, according to its [ChannelType], before fully closing. - /// - /// If the closed channel is the current default channel, the default channel gets set to `None`. - /// - /// Can fail if the [ChannelId] is unknown, or if the channel is already closed. - pub fn close_channel(&mut self, channel_id: ChannelId) -> Result<(), QuinnetError> { - match self.channels.remove(&channel_id) { - Some(channel) => { - if Some(channel_id) == self.default_channel { - self.default_channel = None; - } - channel.close() - } - None => Err(QuinnetError::UnknownChannel(channel_id)), - } - } - - /// Set the default channel - pub fn set_default_channel(&mut self, channel_id: ChannelId) { - self.default_channel = Some(channel_id); - } - - /// Get the default Channel Id - pub fn get_default_channel(&self) -> Option<ChannelId> { - self.default_channel - } - - fn create_channel(&mut self, channel_id: ChannelId) -> Result<ChannelId, QuinnetError> { - let (bytes_to_channel_send, bytes_to_channel_recv) = - mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let (channel_close_send, channel_close_recv) = - mpsc::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); - - match self - .to_channels_send - .try_send(ChannelSyncMessage::CreateChannel { - channel_id, - bytes_to_channel_recv, - channel_close_recv, - }) { - Ok(_) => { - let channel = Channel::new(bytes_to_channel_send, channel_close_send); - self.channels.insert(channel_id, channel); - if self.default_channel.is_none() { - self.default_channel = Some(channel_id); - } - - Ok(channel_id) - } - Err(err) => match err { - TrySendError::Full(_) => Err(QuinnetError::FullQueue), - TrySendError::Closed(_) => Err(QuinnetError::InternalChannelClosed), - }, - } - } -} #[derive(Resource)] pub struct Client { @@ -455,19 +134,15 @@ impl Client { mpsc::channel::<ChannelSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); // Create a close channel for this connection - let (close_sender, close_receiver) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + let (close_send, close_recv) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); - let mut connection = Connection { - state: ConnectionState::Connecting, - channels: HashMap::new(), - last_gen_id: 0, - default_channel: None, + let mut connection = Connection::new( bytes_from_server_recv, - close_sender: close_sender.clone(), + close_send.clone(), from_async_client_recv, to_channels_send, from_channels_recv, - }; + ); // Create default channels let ordered_reliable_id = connection.open_channel(ChannelType::OrderedReliable)?; connection.open_channel(ChannelType::UnorderedReliable)?; @@ -475,15 +150,15 @@ impl Client { // Async connection self.runtime.spawn(async move { - connection_task(ConnectionSpawnConfig { - connection_config: config, + connection_task( + config, cert_mode, + to_sync_client_send, to_channels_recv, from_channels_send, - to_sync_client_send, - close_recv: close_receiver, + close_recv, bytes_from_server_send, - }) + ) .await }); @@ -534,195 +209,6 @@ impl Client { } } -fn configure_client( - cert_mode: CertificateVerificationMode, - to_sync_client: mpsc::Sender<ClientAsyncMessage>, -) -> Result<ClientConfig, Box<dyn Error>> { - match cert_mode { - CertificateVerificationMode::SkipVerification => { - let crypto = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(SkipServerVerification::new()) - .with_no_client_auth(); - - Ok(ClientConfig::new(Arc::new(crypto))) - } - CertificateVerificationMode::SignedByCertificateAuthority => { - Ok(ClientConfig::with_native_roots()) - } - CertificateVerificationMode::TrustOnFirstUse(config) => { - let (store, store_file) = load_known_hosts_store_from_config(config.known_hosts)?; - let crypto = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(TofuServerVerification::new( - store, - config.verifier_behaviour, - to_sync_client, - store_file, - )) - .with_no_client_auth(); - Ok(ClientConfig::new(Arc::new(crypto))) - } - } -} - -async fn connection_task(spawn_config: ConnectionSpawnConfig) { - let config = spawn_config.connection_config; - let server_adr_str = format!("{}:{}", config.server_host, config.server_port); - let srv_host = config.server_host.clone(); - let local_bind_adr = format!("{}:{}", config.local_bind_host, config.local_bind_port); - - info!("Trying to connect to server on: {} ...", server_adr_str); - - let server_addr: SocketAddr = server_adr_str - .parse() - .expect("Failed to parse server address"); - - let client_cfg = configure_client( - spawn_config.cert_mode, - spawn_config.to_sync_client_send.clone(), - ) - .expect("Failed to configure client"); - - let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap()) - .expect("Failed to create client endpoint"); - endpoint.set_default_client_config(client_cfg); - - let connection = endpoint - .connect(server_addr, &srv_host) // TODO Clean: error handling - .expect("Failed to connect: configuration error") - .await; - match connection { - Err(e) => error!("Error while connecting: {}", e), - Ok(connection) => { - info!("Connected to {}", connection.remote_address()); - - spawn_config - .to_sync_client_send - .send(ClientAsyncMessage::Connected(connection.clone())) - .await - .expect("Failed to signal connection to sync client"); - - // Spawn a task to listen for the underlying connection being closed - { - let conn = connection.clone(); - let to_sync_client = spawn_config.to_sync_client_send.clone(); - tokio::spawn(async move { - let conn_err = conn.closed().await; - info!("Disconnected: {}", conn_err); - // If we requested the connection to close, channel may have been closed already. - if !to_sync_client.is_closed() { - to_sync_client - .send(ClientAsyncMessage::ConnectionClosed(conn_err)) - .await - .expect("Failed to signal connection lost to sync client"); - } - }) - }; - - // Spawn a task to listen for streams opened by the server - { - let close_recv = spawn_config.close_recv.resubscribe(); - let connection_handle = connection.clone(); - let bytes_incoming_send = spawn_config.bytes_from_server_send.clone(); - tokio::spawn(async move { - reliable_receiver_task(connection_handle, close_recv, bytes_incoming_send).await - }); - } - - // Spawn a task to listen for datagrams sent by the server - { - let close_recv = spawn_config.close_recv.resubscribe(); - let connection_handle = connection.clone(); - let bytes_incoming_send = spawn_config.bytes_from_server_send.clone(); - tokio::spawn(async move { - unreliable_receiver_task(connection_handle, close_recv, bytes_incoming_send) - .await - }); - } - - // Spawn a task to handle send channels for this connection - tokio::spawn(async move { - channels_task( - connection, - spawn_config.close_recv, - spawn_config.to_channels_recv, - spawn_config.from_channels_send, - ) - .await - }); - } - } -} - -async fn uni_receiver_task( - recv: RecvStream, - mut close_recv: broadcast::Receiver<()>, - bytes_from_server_send: mpsc::Sender<Bytes>, -) { - tokio::select! { - _ = close_recv.recv() => { - trace!("Listener of a Unidirectional Receiving Streams received a close signal") - } - _ = async { - let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - while let Some(Ok(msg_bytes)) = frame_recv.next().await { - // TODO Clean: error handling - bytes_from_server_send.send(msg_bytes.into()).await.unwrap(); - } - } => {} - }; -} - -async fn reliable_receiver_task( - connection: quinn::Connection, - mut close_recv: broadcast::Receiver<()>, - bytes_incoming_send: mpsc::Sender<Bytes>, -) { - let close_recv_clone = close_recv.resubscribe(); - tokio::select! { - _ = close_recv.recv() => { - trace!("Listener for new Unidirectional Receiving Streams received a close signal") - } - _ = async { - while let Ok(recv) = connection.accept_uni().await { - let bytes_from_server_send = bytes_incoming_send.clone(); - let close_recv_clone = close_recv_clone.resubscribe(); - tokio::spawn(async move { - uni_receiver_task( - recv, - close_recv_clone, - bytes_from_server_send - ).await; - }); - } - } => { - trace!("Listener for new Unidirectional Receiving Streams ended") - } - }; - trace!("All unidirectional stream receivers cleaned"); -} - -async fn unreliable_receiver_task( - connection: quinn::Connection, - mut close_recv: broadcast::Receiver<()>, - bytes_incoming_send: mpsc::Sender<Bytes>, -) { - tokio::select! { - _ = close_recv.recv() => { - trace!("Listener for unreliable datagrams received a close signal") - } - _ = async { - while let Ok(msg_bytes) = connection.read_datagram().await { - // TODO Clean: error handling - bytes_incoming_send.send(msg_bytes.into()).await.unwrap(); - } - } => { - trace!("Listener for unreliable datagrams ended") - } - }; -} - // Receive messages from the async client tasks and update the sync client. fn update_sync_client( mut connection_events: EventWriter<ConnectionEvent>, diff --git a/src/client/connection.rs b/src/client/connection.rs new file mode 100644 index 0000000..94bcf2d --- /dev/null +++ b/src/client/connection.rs @@ -0,0 +1,544 @@ +use std::{collections::HashMap, error::Error, net::SocketAddr, sync::Arc}; + +use bevy::prelude::{error, info, trace}; +use bytes::Bytes; +use futures::StreamExt; +use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, RecvStream}; +use quinn_proto::ConnectionStats; + +use serde::Deserialize; +use tokio::sync::{ + broadcast, + mpsc::{ + self, + error::{TryRecvError, TrySendError}, + }, +}; +use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; + +use crate::shared::{ + channel::{ + channels_task, get_channel_id_from_type, Channel, ChannelAsyncMessage, ChannelId, + ChannelSyncMessage, ChannelType, MultiChannelId, + }, + QuinnetError, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE, +}; + +use super::{ + certificate::{ + load_known_hosts_store_from_config, CertificateVerificationMode, SkipServerVerification, + TofuServerVerification, + }, + ClientAsyncMessage, +}; + +pub type ConnectionId = u64; + +/// Connection event raised when the client just connected to the server. Raised in the CoreStage::PreUpdate stage. +pub struct ConnectionEvent { + pub id: ConnectionId, +} +/// ConnectionLost event raised when the client is considered disconnected from the server. Raised in the CoreStage::PreUpdate stage. +pub struct ConnectionLostEvent { + pub id: ConnectionId, +} + +/// Configuration of the client, used when connecting to a server +#[derive(Debug, Deserialize, Clone)] +pub struct ConnectionConfiguration { + server_host: String, + server_port: u16, + local_bind_host: String, + local_bind_port: u16, +} + +impl ConnectionConfiguration { + /// Creates a new ClientConfigurationData + /// + /// # Arguments + /// + /// * `server_host` - Address of the server + /// * `server_port` - Port that the server is listening on + /// * `local_bind_host` - Local address to bind to, which should usually be a wildcard address like `0.0.0.0` or `[::]`, which allow communication with any reachable IPv4 or IPv6 address. See [`quinn::endpoint::Endpoint`] for more precision + /// * `local_bind_port` - Local port to bind to. Use 0 to get an OS-assigned port.. See [`quinn::endpoint::Endpoint`] for more precision + /// + /// # Examples + /// + /// ``` + /// use bevy_quinnet::client::connection::ConnectionConfiguration; + /// let config = ConnectionConfiguration::new( + /// "127.0.0.1".to_string(), + /// 6000, + /// "0.0.0.0".to_string(), + /// 0, + /// ); + /// ``` + pub fn new( + server_host: String, + server_port: u16, + local_bind_host: String, + local_bind_port: u16, + ) -> Self { + Self { + server_host, + server_port, + local_bind_host, + local_bind_port, + } + } +} + +pub(crate) type InternalConnectionRef = QuinnConnection; + +/// Current state of a client connection +#[derive(Debug)] +pub(crate) enum ConnectionState { + Connecting, + Connected(InternalConnectionRef), + Disconnected, +} + +#[derive(Debug)] +pub struct Connection { + pub(crate) state: ConnectionState, + channels: HashMap<ChannelId, Channel>, + default_channel: Option<ChannelId>, + last_gen_id: MultiChannelId, + bytes_from_server_recv: mpsc::Receiver<Bytes>, + close_sender: broadcast::Sender<()>, + + pub(crate) from_async_client_recv: mpsc::Receiver<ClientAsyncMessage>, + pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>, + pub(crate) from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>, +} + +impl Connection { + pub(crate) fn new( + bytes_from_server_recv: mpsc::Receiver<Bytes>, + close_sender: broadcast::Sender<()>, + from_async_client_recv: mpsc::Receiver<ClientAsyncMessage>, + to_channels_send: mpsc::Sender<ChannelSyncMessage>, + from_channels_recv: mpsc::Receiver<ChannelAsyncMessage>, + ) -> Self { + Self { + state: ConnectionState::Connecting, + channels: HashMap::new(), + last_gen_id: 0, + default_channel: None, + bytes_from_server_recv, + close_sender, + from_async_client_recv, + to_channels_send, + from_channels_recv, + } + } + + pub fn receive_message<T: serde::de::DeserializeOwned>( + &mut self, + ) -> Result<Option<T>, QuinnetError> { + match self.receive_payload()? { + Some(payload) => match bincode::deserialize(&payload) { + Ok(msg) => Ok(Some(msg)), + Err(_) => Err(QuinnetError::Deserialization), + }, + None => Ok(None), + } + } + + /// Same as [Connection::receive_message] but will log the error instead of returning it + pub fn try_receive_message<T: serde::de::DeserializeOwned>(&mut self) -> Option<T> { + match self.receive_message() { + Ok(message) => message, + Err(err) => { + error!("try_receive_message: {}", err); + None + } + } + } + + pub fn send_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> { + match self.default_channel { + Some(channel) => self.send_message_on(channel, message), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn send_message_on<T: serde::Serialize>( + &self, + channel_id: ChannelId, + message: T, + ) -> Result<(), QuinnetError> { + match &self.state { + ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), + _ => match self.channels.get(&channel_id) { + Some(channel) => match bincode::serialize(&message) { + Ok(payload) => channel.send_payload(payload), + Err(_) => Err(QuinnetError::Serialization), + }, + None => Err(QuinnetError::UnknownChannel(channel_id)), + }, + } + } + + /// Same as [Connection::send_message] but will log the error instead of returning it + pub fn try_send_message<T: serde::Serialize>(&self, message: T) { + match self.send_message(message) { + Ok(_) => {} + Err(err) => error!("try_send_message: {}", err), + } + } + + pub fn send_payload<T: Into<Bytes>>(&self, payload: T) -> Result<(), QuinnetError> { + match self.default_channel { + Some(channel) => self.send_payload_on(channel, payload), + None => Err(QuinnetError::NoDefaultChannel), + } + } + + pub fn send_payload_on<T: Into<Bytes>>( + &self, + channel_id: ChannelId, + payload: T, + ) -> Result<(), QuinnetError> { + match &self.state { + ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), + _ => match self.channels.get(&channel_id) { + Some(channel) => channel.send_payload(payload), + None => Err(QuinnetError::UnknownChannel(channel_id)), + }, + } + } + + /// Same as [Connection::send_payload] but will log the error instead of returning it + pub fn try_send_payload<T: Into<Bytes>>(&self, payload: T) { + match self.send_payload(payload) { + Ok(_) => {} + Err(err) => error!("try_send_payload: {}", err), + } + } + + pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> { + match &self.state { + ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), + _ => match self.bytes_from_server_recv.try_recv() { + Ok(msg_payload) => Ok(Some(msg_payload)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed), + }, + }, + } + } + + /// Same as [Connection::receive_payload] but will log the error instead of returning it + pub fn try_receive_payload(&mut self) -> Option<Bytes> { + match self.receive_payload() { + Ok(payload) => payload, + Err(err) => { + error!("try_receive_payload: {}", err); + None + } + } + } + + /// Immediately prevents new messages from being sent on the connection and signal the connection to closes all its background tasks. Before trully closing, the connection will wait for all buffered messages in all its opened channels to be properly sent according to their respective channel type. + pub(crate) fn disconnect(&mut self) -> Result<(), QuinnetError> { + match &self.state { + ConnectionState::Disconnected => Ok(()), + _ => { + self.state = ConnectionState::Disconnected; + match self.close_sender.send(()) { + Ok(_) => Ok(()), + Err(_) => { + // The only possible error for a send is that there is no active receivers, meaning that the tasks are already terminated. + Err(QuinnetError::ConnectionAlreadyClosed) + } + } + } + } + } + + pub(crate) fn try_disconnect(&mut self) { + match &self.disconnect() { + Ok(_) => (), + Err(err) => error!("Failed to properly close clonnection: {}", err), + } + } + + pub fn is_connected(&self) -> bool { + match self.state { + ConnectionState::Connected(_) => true, + _ => false, + } + } + + /// Returns statistics about the current connection if connected. + pub fn stats(&self) -> Option<ConnectionStats> { + match &self.state { + ConnectionState::Connected(connection) => Some(connection.stats()), + _ => None, + } + } + + /// Opens a channel of the requested [ChannelType] and returns its [ChannelId]. + /// + /// By default, when starting a [Connection]], Quinnet creates 1 channel instance of each [ChannelType], each with their own [ChannelId]. Among those, there is a `default` channel which will be used when you don't specify the channel. At startup, this default channel is a [ChannelType::OrderedReliable] channel. + /// + /// If no channels were previously opened, the opened channel will be the new default channel. + /// + /// Can fail if the Connection is closed. + pub fn open_channel(&mut self, channel_type: ChannelType) -> Result<ChannelId, QuinnetError> { + let channel_id = get_channel_id_from_type(channel_type, || { + self.last_gen_id += 1; + self.last_gen_id + }); + match self.channels.contains_key(&channel_id) { + true => Ok(channel_id), + false => self.create_channel(channel_id), + } + } + + /// Closes the channel with the corresponding [ChannelId]. + /// + /// No new messages will be able to be sent on this channel, however, the channel will properly try to send all the messages that were previously pushed to it, according to its [ChannelType], before fully closing. + /// + /// If the closed channel is the current default channel, the default channel gets set to `None`. + /// + /// Can fail if the [ChannelId] is unknown, or if the channel is already closed. + pub fn close_channel(&mut self, channel_id: ChannelId) -> Result<(), QuinnetError> { + match self.channels.remove(&channel_id) { + Some(channel) => { + if Some(channel_id) == self.default_channel { + self.default_channel = None; + } + channel.close() + } + None => Err(QuinnetError::UnknownChannel(channel_id)), + } + } + + /// Set the default channel + pub fn set_default_channel(&mut self, channel_id: ChannelId) { + self.default_channel = Some(channel_id); + } + + /// Get the default Channel Id + pub fn get_default_channel(&self) -> Option<ChannelId> { + self.default_channel + } + + fn create_channel(&mut self, channel_id: ChannelId) -> Result<ChannelId, QuinnetError> { + let (bytes_to_channel_send, bytes_to_channel_recv) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (channel_close_send, channel_close_recv) = + mpsc::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + + match self + .to_channels_send + .try_send(ChannelSyncMessage::CreateChannel { + channel_id, + bytes_to_channel_recv, + channel_close_recv, + }) { + Ok(_) => { + let channel = Channel::new(bytes_to_channel_send, channel_close_send); + self.channels.insert(channel_id, channel); + if self.default_channel.is_none() { + self.default_channel = Some(channel_id); + } + + Ok(channel_id) + } + Err(err) => match err { + TrySendError::Full(_) => Err(QuinnetError::FullQueue), + TrySendError::Closed(_) => Err(QuinnetError::InternalChannelClosed), + }, + } + } +} + +pub(crate) async fn connection_task( + config: ConnectionConfiguration, + cert_mode: CertificateVerificationMode, + to_sync_client_send: mpsc::Sender<ClientAsyncMessage>, + to_channels_recv: mpsc::Receiver<ChannelSyncMessage>, + from_channels_send: mpsc::Sender<ChannelAsyncMessage>, + close_recv: broadcast::Receiver<()>, + bytes_from_server_send: mpsc::Sender<Bytes>, +) { + let server_adr_str = format!("{}:{}", config.server_host, config.server_port); + let srv_host = config.server_host.clone(); + let local_bind_adr = format!("{}:{}", config.local_bind_host, config.local_bind_port); + + info!("Trying to connect to server on: {} ...", server_adr_str); + + let server_addr: SocketAddr = server_adr_str + .parse() + .expect("Failed to parse server address"); + + let client_cfg = configure_client(cert_mode, to_sync_client_send.clone()) + .expect("Failed to configure client"); + + let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap()) + .expect("Failed to create client endpoint"); + endpoint.set_default_client_config(client_cfg); + + let connection = endpoint + .connect(server_addr, &srv_host) // TODO Clean: error handling + .expect("Failed to connect: configuration error") + .await; + match connection { + Err(e) => error!("Error while connecting: {}", e), + Ok(connection) => { + info!("Connected to {}", connection.remote_address()); + + to_sync_client_send + .send(ClientAsyncMessage::Connected(connection.clone())) + .await + .expect("Failed to signal connection to sync client"); + + // Spawn a task to listen for the underlying connection being closed + { + let conn = connection.clone(); + let to_sync_client = to_sync_client_send.clone(); + tokio::spawn(async move { + let conn_err = conn.closed().await; + info!("Disconnected: {}", conn_err); + // If we requested the connection to close, channel may have been closed already. + if !to_sync_client.is_closed() { + to_sync_client + .send(ClientAsyncMessage::ConnectionClosed(conn_err)) + .await + .expect("Failed to signal connection lost to sync client"); + } + }) + }; + + // Spawn a task to listen for streams opened by the server + { + let close_recv = close_recv.resubscribe(); + let connection_handle = connection.clone(); + let bytes_incoming_send = bytes_from_server_send.clone(); + tokio::spawn(async move { + reliable_receiver_task(connection_handle, close_recv, bytes_incoming_send).await + }); + } + + // Spawn a task to listen for datagrams sent by the server + { + let close_recv = close_recv.resubscribe(); + let connection_handle = connection.clone(); + let bytes_incoming_send = bytes_from_server_send.clone(); + tokio::spawn(async move { + unreliable_receiver_task(connection_handle, close_recv, bytes_incoming_send) + .await + }); + } + + // Spawn a task to handle send channels for this connection + tokio::spawn(async move { + channels_task(connection, close_recv, to_channels_recv, from_channels_send).await + }); + } + } +} + +async fn uni_receiver_task( + recv: RecvStream, + mut close_recv: broadcast::Receiver<()>, + bytes_from_server_send: mpsc::Sender<Bytes>, +) { + tokio::select! { + _ = close_recv.recv() => { + trace!("Listener of a Unidirectional Receiving Streams received a close signal") + } + _ = async { + let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); + while let Some(Ok(msg_bytes)) = frame_recv.next().await { + // TODO Clean: error handling + bytes_from_server_send.send(msg_bytes.into()).await.unwrap(); + } + } => {} + }; +} + +async fn reliable_receiver_task( + connection: quinn::Connection, + mut close_recv: broadcast::Receiver<()>, + bytes_incoming_send: mpsc::Sender<Bytes>, +) { + let close_recv_clone = close_recv.resubscribe(); + tokio::select! { + _ = close_recv.recv() => { + trace!("Listener for new Unidirectional Receiving Streams received a close signal") + } + _ = async { + while let Ok(recv) = connection.accept_uni().await { + let bytes_from_server_send = bytes_incoming_send.clone(); + let close_recv_clone = close_recv_clone.resubscribe(); + tokio::spawn(async move { + uni_receiver_task( + recv, + close_recv_clone, + bytes_from_server_send + ).await; + }); + } + } => { + trace!("Listener for new Unidirectional Receiving Streams ended") + } + }; + trace!("All unidirectional stream receivers cleaned"); +} + +async fn unreliable_receiver_task( + connection: quinn::Connection, + mut close_recv: broadcast::Receiver<()>, + bytes_incoming_send: mpsc::Sender<Bytes>, +) { + tokio::select! { + _ = close_recv.recv() => { + trace!("Listener for unreliable datagrams received a close signal") + } + _ = async { + while let Ok(msg_bytes) = connection.read_datagram().await { + // TODO Clean: error handling + bytes_incoming_send.send(msg_bytes.into()).await.unwrap(); + } + } => { + trace!("Listener for unreliable datagrams ended") + } + }; +} + +fn configure_client( + cert_mode: CertificateVerificationMode, + to_sync_client: mpsc::Sender<ClientAsyncMessage>, +) -> Result<ClientConfig, Box<dyn Error>> { + match cert_mode { + CertificateVerificationMode::SkipVerification => { + let crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(); + + Ok(ClientConfig::new(Arc::new(crypto))) + } + CertificateVerificationMode::SignedByCertificateAuthority => { + Ok(ClientConfig::with_native_roots()) + } + CertificateVerificationMode::TrustOnFirstUse(config) => { + let (store, store_file) = load_known_hosts_store_from_config(config.known_hosts)?; + let crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(TofuServerVerification::new( + store, + config.verifier_behaviour, + to_sync_client, + store_file, + )) + .with_no_client_auth(); + Ok(ClientConfig::new(Arc::new(crypto))) + } + } +} @@ -20,7 +20,8 @@ mod tests { CertVerificationInfo, CertVerificationStatus, CertVerifierAction, CertificateVerificationMode, }, - Client, ConnectionConfiguration, QuinnetClientPlugin, DEFAULT_KNOWN_HOSTS_FILE, + connection::ConnectionConfiguration, + Client, QuinnetClientPlugin, DEFAULT_KNOWN_HOSTS_FILE, }, server::{ self, certificate::CertificateRetrievalMode, QuinnetServerPlugin, Server, @@ -509,7 +510,7 @@ mod tests { } fn handle_client_events( - mut connection_events: EventReader<client::ConnectionEvent>, + mut connection_events: EventReader<client::connection::ConnectionEvent>, mut cert_trust_update_events: EventReader<CertTrustUpdateEvent>, mut cert_interaction_events: EventReader<CertInteractionEvent>, mut cert_connection_abort_events: EventReader<CertConnectionAbortEvent>, diff --git a/src/shared.rs b/src/shared.rs index 5e912af..2fdbf4c 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -1,6 +1,6 @@ use std::{fmt, io, net::AddrParseError, sync::PoisonError}; -use crate::client::ConnectionId; +use crate::client::connection::ConnectionId; use bevy::prelude::{Deref, DerefMut, Resource}; use rcgen::RcgenError; use tokio::runtime::Runtime; |