diff options
-rw-r--r-- | src/client.rs | 146 | ||||
-rw-r--r-- | src/lib.rs | 3 |
2 files changed, 125 insertions, 24 deletions
diff --git a/src/client.rs b/src/client.rs index c415a2b..306f010 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,8 @@ use std::{ + collections::{ + hash_map::{Iter, IterMut}, + HashMap, + }, error::Error, net::SocketAddr, sync::{Arc, Mutex}, @@ -39,7 +43,7 @@ pub mod certificate; pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100; pub const DEFAULT_KNOWN_HOSTS_FILE: &str = "quinnet/known_hosts"; -pub type ConnectionId = Entity; +pub type ConnectionId = u64; /// Connection event raised when the client just connected to the server. Raised in the CoreStage::PreUpdate stage. pub struct ConnectionEvent(ConnectionId); @@ -124,7 +128,6 @@ pub(crate) struct ConnectionSpawnConfig { from_server_sender: mpsc::Sender<Bytes>, } -#[derive(Component)] pub struct Connection { state: ConnectionState, // TODO Perf: multiple channels @@ -135,7 +138,7 @@ pub struct Connection { } impl Connection { - /// Closes this connection. This does not send any message to the server, and simply closes all the connection's tasks locally. + /// Disconnect from the server on this connection. This does not send any message to the server, and simply closes all the connection's tasks locally. pub fn disconnect(&mut self) -> Result<(), QuinnetError> { if self.is_connected() { if let Err(_) = self.close_sender.send(()) { @@ -193,13 +196,65 @@ impl Connection { #[derive(Resource)] pub struct Client { runtime: runtime::Handle, + connections: HashMap<ConnectionId, Connection>, + last_gen_id: ConnectionId, + default_connection_id: Option<ConnectionId>, } impl Client { - /// Sapwn a connection to a server with the given [ConnectionConfiguration] and [CertificateVerificationMode] - pub fn spawn_connection( - &self, - commands: &mut Commands, + /// Returns the default connection or None. + pub fn get_connection(&self) -> Option<&Connection> { + match self.default_connection_id { + Some(id) => self.connections.get(&id), + None => None, + } + } + + /// Returns the default connection as mut or None. + pub fn get_connection_mut(&mut self) -> Option<&mut Connection> { + match self.default_connection_id { + Some(id) => self.connections.get_mut(&id), + None => None, + } + } + + /// Returns the default connection. **Warning**, this function panics if there is no default connection. + pub fn connection(&self) -> &Connection { + self.connections + .get(&self.default_connection_id.unwrap()) + .unwrap() + } + + /// Returns the default connection as mut. **Warning**, this function panics if there is no default connection. + pub fn connection_mut(&mut self) -> &mut Connection { + self.connections + .get_mut(&self.default_connection_id.unwrap()) + .unwrap() + } + + /// Returns the requested connection. + pub fn get_connection_by_id(&self, id: ConnectionId) -> Option<&Connection> { + self.connections.get(&id) + } + + /// Returns the requested connection as mut. + pub fn get_connection_mut_by_id(&mut self, id: ConnectionId) -> Option<&mut Connection> { + self.connections.get_mut(&id) + } + + /// Returns an iterator over all connections + pub fn connections(&self) -> Iter<ConnectionId, Connection> { + self.connections.iter() + } + + /// Returns an iterator over all connections as muts + pub fn connections_mut(&mut self) -> IterMut<ConnectionId, Connection> { + self.connections.iter_mut() + } + + /// Open a connection to a server with the given [ConnectionConfiguration] and [CertificateVerificationMode]. The connection will raise an event when fully connected, see [ConnectionEvent] + pub fn open_connection( + &mut self, config: ConnectionConfiguration, cert_mode: CertificateVerificationMode, ) -> ConnectionId { @@ -217,15 +272,13 @@ impl Client { tokio::sync::broadcast::Receiver<()>, ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); - let connection = commands - .spawn(Connection { - state: ConnectionState::Disconnected, - sender: to_server_sender, - receiver: from_server_receiver, - close_sender: close_sender.clone(), - internal_receiver: from_async_client, - }) - .id(); + let connection = Connection { + state: ConnectionState::Disconnected, + sender: to_server_sender, + receiver: from_server_receiver, + close_sender: close_sender.clone(), + internal_receiver: from_async_client, + }; // Async connection self.runtime.spawn(async move { @@ -240,7 +293,49 @@ impl Client { }) .await }); - connection + + self.last_gen_id += 1; + let connection_id = self.last_gen_id; + self.connections.insert(connection_id, connection); + if self.default_connection_id.is_none() { + self.default_connection_id = Some(connection_id); + } + + connection_id + } + + /// Set the default connection + pub fn set_default_connection(&mut self, connection_id: ConnectionId) { + self.default_connection_id = Some(connection_id); + } + + /// Close a specific connection. This will call disconnect on the connection and remove it from the client. This may fail if the [Connection] fails to disconnect or if no [Connection] if found for connection_id + pub fn close_connection(&mut self, connection_id: ConnectionId) -> Result<(), QuinnetError> { + match self.connections.remove(&connection_id) { + Some(mut connection) => { + connection.disconnect()?; + if let Some(default_id) = self.default_connection_id { + if connection_id == default_id { + self.default_connection_id = None; + } + } + Ok(()) + } + None => Err(QuinnetError::UnknownConnection(connection_id)), + } + } + + /// Calls close_connection on all the open connections. + pub fn close_all_connections(&mut self) -> Result<(), QuinnetError> { + for connection_id in self + .connections + .keys() + .cloned() + .collect::<Vec<ConnectionId>>() + { + self.close_connection(connection_id)?; + } + Ok(()) } } @@ -382,18 +477,18 @@ fn update_sync_client( mut certificate_interaction_events: EventWriter<CertInteractionEvent>, mut cert_trust_update_events: EventWriter<CertTrustUpdateEvent>, mut cert_connection_abort_events: EventWriter<CertConnectionAbortEvent>, - mut connections: Query<(&mut Connection, ConnectionId)>, + mut client: ResMut<Client>, ) { - for (mut connection, connection_id) in connections.iter_mut() { + for (connection_id, mut connection) in &mut client.connections { while let Ok(message) = connection.internal_receiver.try_recv() { match message { InternalAsyncMessage::Connected => { connection.state = ConnectionState::Connected; - connection_events.send(ConnectionEvent(connection_id)); + connection_events.send(ConnectionEvent(*connection_id)); } InternalAsyncMessage::LostConnection => { connection.state = ConnectionState::Disconnected; - connection_lost_events.send(ConnectionLostEvent(connection_id)); + connection_lost_events.send(ConnectionLostEvent(*connection_id)); } InternalAsyncMessage::CertificateInteractionRequest { status, @@ -401,7 +496,7 @@ fn update_sync_client( action_sender, } => { certificate_interaction_events.send(CertInteractionEvent { - connection_id, + connection_id: *connection_id, status, info, action_sender: Mutex::new(Some(action_sender)), @@ -409,13 +504,13 @@ fn update_sync_client( } InternalAsyncMessage::CertificateTrustUpdate(info) => { cert_trust_update_events.send(CertTrustUpdateEvent { - connection_id, + connection_id: *connection_id, cert_info: info, }); } InternalAsyncMessage::CertificateConnectionAbort { status, cert_info } => { cert_connection_abort_events.send(CertConnectionAbortEvent { - connection_id, + connection_id: *connection_id, status, cert_info, }); @@ -427,7 +522,10 @@ fn update_sync_client( fn create_client(mut commands: Commands, runtime: Res<AsyncRuntime>) { commands.insert_resource(Client { + connections: HashMap::new(), runtime: runtime.handle().clone(), + last_gen_id: 0, + default_connection_id: None, }); } @@ -1,6 +1,7 @@ use std::sync::PoisonError; use bevy::prelude::{Deref, DerefMut, Resource}; +use client::ConnectionId; use tokio::runtime::Runtime; pub const DEFAULT_MESSAGE_QUEUE_SIZE: usize = 150; @@ -20,6 +21,8 @@ pub(crate) struct AsyncRuntime(pub(crate) Runtime); pub enum QuinnetError { #[error("Client with id `{0}` is unknown")] UnknownClient(ClientId), + #[error("Connection with id `{0}` is unknown")] + UnknownConnection(ConnectionId), #[error("Failed serialization")] Serialization, #[error("Failed deserialization")] |