diff options
author | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2022-11-14 18:51:27 +0100 |
---|---|---|
committer | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2022-11-16 16:32:05 +0100 |
commit | 1f307cf640ffa6573d5191d50988176cb917af08 (patch) | |
tree | c46482f2ca2aa8e3db4fbd20b222e47b25c1871f /src | |
parent | cc1eb6bbbf1744b6d5f012a2c807992db226d971 (diff) |
[client] Start multiple connections implementation
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 251 | ||||
-rw-r--r-- | src/client/certificate.rs | 9 |
2 files changed, 137 insertions, 123 deletions
diff --git a/src/client.rs b/src/client.rs index ec9cc46..0a382bc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,6 +11,7 @@ use futures_util::StreamExt; use quinn::{ClientConfig, Endpoint}; use serde::Deserialize; use tokio::{ + runtime::{self}, sync::{ broadcast, mpsc::{ @@ -38,21 +39,23 @@ pub mod certificate; pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100; pub const DEFAULT_KNOWN_HOSTS_FILE: &str = "quinnet/known_hosts"; +pub type ConnectionId = Entity; + /// Connection event raised when the client just connected to the server. Raised in the CoreStage::PreUpdate stage. -pub struct ConnectionEvent; +pub struct ConnectionEvent(ConnectionId); /// ConnectionLost event raised when the client is considered disconnected from the server. Raised in the CoreStage::PreUpdate stage. -pub struct ConnectionLostEvent; +pub struct ConnectionLostEvent(ConnectionId); /// Configuration of the client, used when connecting to a server #[derive(Debug, Deserialize, Clone)] -pub struct ClientConfigurationData { +pub struct ConnectionConfiguration { server_host: String, server_port: u16, local_bind_host: String, local_bind_port: u16, } -impl ClientConfigurationData { +impl ConnectionConfiguration { /// Creates a new ClientConfigurationData /// /// # Arguments @@ -89,7 +92,7 @@ impl ClientConfigurationData { /// Current state of the client driver #[derive(Debug, PartialEq, Eq)] -enum ClientState { +enum ConnectionState { Disconnected, Connected, } @@ -110,42 +113,29 @@ pub(crate) enum InternalAsyncMessage { }, } -#[derive(Debug, Clone)] -pub(crate) enum InternalSyncMessage { - Connect { - config: ClientConfigurationData, - cert_mode: CertificateVerificationMode, - }, +#[derive(Debug)] +pub(crate) struct ConnectionSpawnConfig { + connection_config: ConnectionConfiguration, + cert_mode: CertificateVerificationMode, + to_sync_client: mpsc::Sender<InternalAsyncMessage>, + close_sender: tokio::sync::broadcast::Sender<()>, + close_receiver: tokio::sync::broadcast::Receiver<()>, + to_server_receiver: mpsc::Receiver<Bytes>, + from_server_sender: mpsc::Sender<Bytes>, } -#[derive(Resource)] -pub struct Client { - state: ClientState, +#[derive(Component)] +pub struct Connection { + state: ConnectionState, // TODO Perf: multiple channels sender: mpsc::Sender<Bytes>, receiver: mpsc::Receiver<Bytes>, close_sender: broadcast::Sender<()>, - pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>, - pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>, + // pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>, } -impl Client { - /// Connect to a server with the given [ClientConfigurationData] and [CertificateVerificationMode] - pub fn connect( - &self, - config: ClientConfigurationData, - cert_mode: CertificateVerificationMode, - ) -> Result<(), QuinnetError> { - match self - .internal_sender - .try_send(InternalSyncMessage::Connect { config, cert_mode }) - { - Ok(_) => Ok(()), - Err(_) => Err(QuinnetError::FullQueue), - } - } - +impl Connection { /// Disconnect the client. This does not send any message to the server, and simply closes all the connection tasks locally. pub fn disconnect(&mut self) -> Result<(), QuinnetError> { if self.is_connected() { @@ -153,7 +143,7 @@ impl Client { return Err(QuinnetError::ChannelClosed); } } - self.state = ClientState::Disconnected; + self.state = ConnectionState::Disconnected; Ok(()) } @@ -197,7 +187,63 @@ impl Client { } pub fn is_connected(&self) -> bool { - return self.state == ClientState::Connected; + return self.state == ConnectionState::Connected; + } +} + +#[derive(Resource)] +pub struct Client { + // connections: HashMap<ConnectionId, Connection>, + runtime: runtime::Handle, +} + +impl Client { + /// Connect to a server with the given [ClientConfigurationData] and [CertificateVerificationMode] + pub fn spawn_connection( + &self, + commands: &mut Commands, + config: ConnectionConfiguration, + cert_mode: CertificateVerificationMode, + ) -> ConnectionId { + let (from_server_sender, from_server_receiver) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (to_server_sender, to_server_receiver) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + + let (to_sync_client, from_async_client) = + mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + + // Create a close channel for this connection + let (close_sender, close_receiver): ( + tokio::sync::broadcast::Sender<()>, + tokio::sync::broadcast::Receiver<()>, + ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + + let connection = commands + .spawn(Connection { + state: ConnectionState::Disconnected, + sender: to_server_sender, + receiver: from_server_receiver, + close_sender: close_sender.clone(), + internal_receiver: from_async_client, + // internal_sender: to_async_client, + }) + .id(); + + // Async connection + self.runtime.spawn(async move { + connection_task(ConnectionSpawnConfig { + connection_config: config, + cert_mode, + to_sync_client, + close_sender, + close_receiver, + to_server_receiver, + from_server_sender, + }) + .await + }); + connection } } @@ -233,15 +279,8 @@ fn configure_client( } } -async fn connection_task( - config: ClientConfigurationData, - cert_mode: CertificateVerificationMode, - to_sync_client: mpsc::Sender<InternalAsyncMessage>, - close_sender: tokio::sync::broadcast::Sender<()>, - mut close_receiver: tokio::sync::broadcast::Receiver<()>, - mut to_server_receiver: mpsc::Receiver<Bytes>, - from_server_sender: mpsc::Sender<Bytes>, -) { +async fn connection_task(mut spawn_config: ConnectionSpawnConfig) { + let config = spawn_config.connection_config; let server_adr_str = format!("{}:{}", config.server_host, config.server_port); let srv_host = config.server_host.clone(); let local_bind_adr = format!("{}:{}", config.local_bind_host, config.local_bind_port); @@ -252,8 +291,8 @@ async fn connection_task( .parse() .expect("Failed to parse server address"); - let client_cfg = - configure_client(cert_mode, to_sync_client.clone()).expect("Failed to configure client"); + let client_cfg = configure_client(spawn_config.cert_mode, spawn_config.to_sync_client.clone()) + .expect("Failed to configure client"); let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap()) .expect("Failed to create client endpoint"); @@ -271,7 +310,8 @@ async fn connection_task( new_connection.connection.remote_address() ); - to_sync_client + spawn_config + .to_sync_client .send(InternalAsyncMessage::Connected) .await .expect("Failed to signal connection to sync client"); @@ -283,21 +323,21 @@ async fn connection_task( .expect("Failed to open send stream"); let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new()); - let close_sender_clone = close_sender.clone(); + let close_sender_clone = spawn_config.close_sender.clone(); let _network_sends = tokio::spawn(async move { tokio::select! { - _ = close_receiver.recv() => { + _ = spawn_config.close_receiver.recv() => { trace!("Unidirectional send Stream forced to disconnected") } _ = async { - while let Some(msg_bytes) = to_server_receiver.recv().await { + while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await { if let Err(err) = frame_send.send(msg_bytes).await { error!("Error while sending, {}", err); // TODO Clean: error handling error!("Client seems disconnected, closing resources"); if let Err(_) = close_sender_clone.send(()) { error!("Failed to close all client streams & resources") } - to_sync_client.send( + spawn_config.to_sync_client.send( InternalAsyncMessage::LostConnection) .await .expect("Failed to signal connection lost to sync client"); @@ -310,7 +350,7 @@ async fn connection_task( }); let mut uni_receivers: JoinSet<()> = JoinSet::new(); - let mut close_receiver = close_sender.subscribe(); + let mut close_receiver = spawn_config.close_sender.subscribe(); let _network_reads = tokio::spawn(async move { tokio::select! { _ = close_receiver.recv() => { @@ -319,7 +359,7 @@ async fn connection_task( _ = async { while let Some(Ok(recv)) = new_connection.uni_streams.next().await { let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - let from_server_sender = from_server_sender.clone(); + let from_server_sender = spawn_config.from_server_sender.clone(); uni_receivers.spawn(async move { while let Some(Ok(msg_bytes)) = frame_recv.next().await { @@ -338,88 +378,57 @@ async fn connection_task( } } -fn start_async_client(mut commands: Commands, runtime: Res<AsyncRuntime>) { - let (from_server_sender, from_server_receiver) = - mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let (to_server_sender, to_server_receiver) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - - let (to_sync_client, from_async_client) = - mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - let (to_async_client, mut from_sync_client) = - mpsc::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - - // Create a close channel for this connection - let (close_sender, close_receiver): ( - tokio::sync::broadcast::Sender<()>, - tokio::sync::broadcast::Receiver<()>, - ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); - +fn create_client(mut commands: Commands, runtime: Res<AsyncRuntime>) { commands.insert_resource(Client { - state: ClientState::Disconnected, - sender: to_server_sender, - receiver: from_server_receiver, - close_sender: close_sender.clone(), - internal_receiver: from_async_client, - internal_sender: to_async_client.clone(), - }); - - // Async client - runtime.spawn(async move { - // Wait for a connection signal before starting client - if let Some(message) = from_sync_client.recv().await { - match message { - InternalSyncMessage::Connect { config, cert_mode } => { - connection_task( - config, - cert_mode, - to_sync_client, - close_sender, - close_receiver, - to_server_receiver, - from_server_sender, - ) - .await; - } - } - } + runtime: runtime.handle().clone(), }); } // Receive messages from the async client tasks and update the sync client. fn update_sync_client( - mut client: ResMut<Client>, mut connection_events: EventWriter<ConnectionEvent>, mut connection_lost_events: EventWriter<ConnectionLostEvent>, mut certificate_interaction_events: EventWriter<CertInteractionEvent>, mut cert_trust_update_events: EventWriter<CertTrustUpdateEvent>, mut cert_connection_abort_events: EventWriter<CertConnectionAbortEvent>, + mut connections: Query<(&mut Connection, ConnectionId)>, ) { - while let Ok(message) = client.internal_receiver.try_recv() { - match message { - InternalAsyncMessage::Connected => { - client.state = ClientState::Connected; - connection_events.send(ConnectionEvent); - } - InternalAsyncMessage::LostConnection => { - client.state = ClientState::Disconnected; - connection_lost_events.send(ConnectionLostEvent); - } - InternalAsyncMessage::CertificateInteractionRequest { - status, - info, - action_sender, - } => { - certificate_interaction_events.send(CertInteractionEvent { + for (mut connection, connection_id) in connections.iter_mut() { + while let Ok(message) = connection.internal_receiver.try_recv() { + match message { + InternalAsyncMessage::Connected => { + connection.state = ConnectionState::Connected; + connection_events.send(ConnectionEvent(connection_id)); + } + InternalAsyncMessage::LostConnection => { + connection.state = ConnectionState::Disconnected; + connection_lost_events.send(ConnectionLostEvent(connection_id)); + } + InternalAsyncMessage::CertificateInteractionRequest { status, info, - action_sender: Mutex::new(Some(action_sender)), - }); - } - InternalAsyncMessage::CertificateTrustUpdate(info) => { - cert_trust_update_events.send(CertTrustUpdateEvent(info)); - } - InternalAsyncMessage::CertificateConnectionAbort { status, cert_info } => { - cert_connection_abort_events.send(CertConnectionAbortEvent { status, cert_info }); + action_sender, + } => { + certificate_interaction_events.send(CertInteractionEvent { + connection_id, + status, + info, + action_sender: Mutex::new(Some(action_sender)), + }); + } + InternalAsyncMessage::CertificateTrustUpdate(info) => { + cert_trust_update_events.send(CertTrustUpdateEvent { + connection_id, + cert_info: info, + }); + } + InternalAsyncMessage::CertificateConnectionAbort { status, cert_info } => { + cert_connection_abort_events.send(CertConnectionAbortEvent { + connection_id, + status, + cert_info, + }); + } } } } @@ -441,7 +450,7 @@ impl Plugin for QuinnetClientPlugin { .add_event::<CertTrustUpdateEvent>() .add_event::<CertConnectionAbortEvent>() // StartupStage::PreStartup so that resources created in commands are available to default startup_systems - .add_startup_system_to_stage(StartupStage::PreStartup, start_async_client) + .add_startup_system_to_stage(StartupStage::PreStartup, create_client) .add_system(update_sync_client); if app.world.get_resource_mut::<AsyncRuntime>().is_none() { diff --git a/src/client/certificate.rs b/src/client/certificate.rs index 94f2abe..d773f5e 100644 --- a/src/client/certificate.rs +++ b/src/client/certificate.rs @@ -15,13 +15,14 @@ use tokio::sync::{mpsc, oneshot}; use crate::QuinnetError; -use super::{InternalAsyncMessage, DEFAULT_KNOWN_HOSTS_FILE}; +use super::{ConnectionId, InternalAsyncMessage, DEFAULT_KNOWN_HOSTS_FILE}; pub const DEFAULT_CERT_VERIFIER_BEHAVIOUR: CertVerifierBehaviour = CertVerifierBehaviour::ImmediateAction(CertVerifierAction::AbortConnection); /// Event raised when a user/app interaction is needed for the server's certificate validation pub struct CertInteractionEvent { + pub connection_id: ConnectionId, /// The current status of the verification pub status: CertVerificationStatus, /// Server & Certificate info @@ -48,10 +49,14 @@ impl CertInteractionEvent { } /// Event raised when a new certificate is trusted -pub struct CertTrustUpdateEvent(pub CertVerificationInfo); +pub struct CertTrustUpdateEvent { + pub connection_id: ConnectionId, + pub cert_info: CertVerificationInfo, +} /// Event raised when a connection is aborted during the certificate verification pub struct CertConnectionAbortEvent { + pub connection_id: ConnectionId, pub status: CertVerificationStatus, pub cert_info: CertVerificationInfo, } |