aboutsummaryrefslogtreecommitdiff
path: root/src/client.rs
diff options
context:
space:
mode:
authorHenauxg <19689618+Henauxg@users.noreply.github.com>2022-11-16 21:31:52 +0100
committerHenauxg <19689618+Henauxg@users.noreply.github.com>2022-11-16 21:31:52 +0100
commita07f3ae09dd254e04a9bc42ca891b881e5b5f17b (patch)
tree6e09c7797349914075ef2b3395bbe7c286d2a12d /src/client.rs
parentee8fdbe772d7b81070a34fc9a4e722d18c3fe2af (diff)
[client] Implement multiple connections within the client resource
Diffstat (limited to 'src/client.rs')
-rw-r--r--src/client.rs146
1 files changed, 122 insertions, 24 deletions
diff --git a/src/client.rs b/src/client.rs
index c415a2b..306f010 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,4 +1,8 @@
use std::{
+ collections::{
+ hash_map::{Iter, IterMut},
+ HashMap,
+ },
error::Error,
net::SocketAddr,
sync::{Arc, Mutex},
@@ -39,7 +43,7 @@ pub mod certificate;
pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100;
pub const DEFAULT_KNOWN_HOSTS_FILE: &str = "quinnet/known_hosts";
-pub type ConnectionId = Entity;
+pub type ConnectionId = u64;
/// Connection event raised when the client just connected to the server. Raised in the CoreStage::PreUpdate stage.
pub struct ConnectionEvent(ConnectionId);
@@ -124,7 +128,6 @@ pub(crate) struct ConnectionSpawnConfig {
from_server_sender: mpsc::Sender<Bytes>,
}
-#[derive(Component)]
pub struct Connection {
state: ConnectionState,
// TODO Perf: multiple channels
@@ -135,7 +138,7 @@ pub struct Connection {
}
impl Connection {
- /// Closes this connection. This does not send any message to the server, and simply closes all the connection's tasks locally.
+ /// Disconnect from the server on this connection. This does not send any message to the server, and simply closes all the connection's tasks locally.
pub fn disconnect(&mut self) -> Result<(), QuinnetError> {
if self.is_connected() {
if let Err(_) = self.close_sender.send(()) {
@@ -193,13 +196,65 @@ impl Connection {
#[derive(Resource)]
pub struct Client {
runtime: runtime::Handle,
+ connections: HashMap<ConnectionId, Connection>,
+ last_gen_id: ConnectionId,
+ default_connection_id: Option<ConnectionId>,
}
impl Client {
- /// Sapwn a connection to a server with the given [ConnectionConfiguration] and [CertificateVerificationMode]
- pub fn spawn_connection(
- &self,
- commands: &mut Commands,
+ /// Returns the default connection or None.
+ pub fn get_connection(&self) -> Option<&Connection> {
+ match self.default_connection_id {
+ Some(id) => self.connections.get(&id),
+ None => None,
+ }
+ }
+
+ /// Returns the default connection as mut or None.
+ pub fn get_connection_mut(&mut self) -> Option<&mut Connection> {
+ match self.default_connection_id {
+ Some(id) => self.connections.get_mut(&id),
+ None => None,
+ }
+ }
+
+ /// Returns the default connection. **Warning**, this function panics if there is no default connection.
+ pub fn connection(&self) -> &Connection {
+ self.connections
+ .get(&self.default_connection_id.unwrap())
+ .unwrap()
+ }
+
+ /// Returns the default connection as mut. **Warning**, this function panics if there is no default connection.
+ pub fn connection_mut(&mut self) -> &mut Connection {
+ self.connections
+ .get_mut(&self.default_connection_id.unwrap())
+ .unwrap()
+ }
+
+ /// Returns the requested connection.
+ pub fn get_connection_by_id(&self, id: ConnectionId) -> Option<&Connection> {
+ self.connections.get(&id)
+ }
+
+ /// Returns the requested connection as mut.
+ pub fn get_connection_mut_by_id(&mut self, id: ConnectionId) -> Option<&mut Connection> {
+ self.connections.get_mut(&id)
+ }
+
+ /// Returns an iterator over all connections
+ pub fn connections(&self) -> Iter<ConnectionId, Connection> {
+ self.connections.iter()
+ }
+
+ /// Returns an iterator over all connections as muts
+ pub fn connections_mut(&mut self) -> IterMut<ConnectionId, Connection> {
+ self.connections.iter_mut()
+ }
+
+ /// Open a connection to a server with the given [ConnectionConfiguration] and [CertificateVerificationMode]. The connection will raise an event when fully connected, see [ConnectionEvent]
+ pub fn open_connection(
+ &mut self,
config: ConnectionConfiguration,
cert_mode: CertificateVerificationMode,
) -> ConnectionId {
@@ -217,15 +272,13 @@ impl Client {
tokio::sync::broadcast::Receiver<()>,
) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
- let connection = commands
- .spawn(Connection {
- state: ConnectionState::Disconnected,
- sender: to_server_sender,
- receiver: from_server_receiver,
- close_sender: close_sender.clone(),
- internal_receiver: from_async_client,
- })
- .id();
+ let connection = Connection {
+ state: ConnectionState::Disconnected,
+ sender: to_server_sender,
+ receiver: from_server_receiver,
+ close_sender: close_sender.clone(),
+ internal_receiver: from_async_client,
+ };
// Async connection
self.runtime.spawn(async move {
@@ -240,7 +293,49 @@ impl Client {
})
.await
});
- connection
+
+ self.last_gen_id += 1;
+ let connection_id = self.last_gen_id;
+ self.connections.insert(connection_id, connection);
+ if self.default_connection_id.is_none() {
+ self.default_connection_id = Some(connection_id);
+ }
+
+ connection_id
+ }
+
+ /// Set the default connection
+ pub fn set_default_connection(&mut self, connection_id: ConnectionId) {
+ self.default_connection_id = Some(connection_id);
+ }
+
+ /// Close a specific connection. This will call disconnect on the connection and remove it from the client. This may fail if the [Connection] fails to disconnect or if no [Connection] if found for connection_id
+ pub fn close_connection(&mut self, connection_id: ConnectionId) -> Result<(), QuinnetError> {
+ match self.connections.remove(&connection_id) {
+ Some(mut connection) => {
+ connection.disconnect()?;
+ if let Some(default_id) = self.default_connection_id {
+ if connection_id == default_id {
+ self.default_connection_id = None;
+ }
+ }
+ Ok(())
+ }
+ None => Err(QuinnetError::UnknownConnection(connection_id)),
+ }
+ }
+
+ /// Calls close_connection on all the open connections.
+ pub fn close_all_connections(&mut self) -> Result<(), QuinnetError> {
+ for connection_id in self
+ .connections
+ .keys()
+ .cloned()
+ .collect::<Vec<ConnectionId>>()
+ {
+ self.close_connection(connection_id)?;
+ }
+ Ok(())
}
}
@@ -382,18 +477,18 @@ fn update_sync_client(
mut certificate_interaction_events: EventWriter<CertInteractionEvent>,
mut cert_trust_update_events: EventWriter<CertTrustUpdateEvent>,
mut cert_connection_abort_events: EventWriter<CertConnectionAbortEvent>,
- mut connections: Query<(&mut Connection, ConnectionId)>,
+ mut client: ResMut<Client>,
) {
- for (mut connection, connection_id) in connections.iter_mut() {
+ for (connection_id, mut connection) in &mut client.connections {
while let Ok(message) = connection.internal_receiver.try_recv() {
match message {
InternalAsyncMessage::Connected => {
connection.state = ConnectionState::Connected;
- connection_events.send(ConnectionEvent(connection_id));
+ connection_events.send(ConnectionEvent(*connection_id));
}
InternalAsyncMessage::LostConnection => {
connection.state = ConnectionState::Disconnected;
- connection_lost_events.send(ConnectionLostEvent(connection_id));
+ connection_lost_events.send(ConnectionLostEvent(*connection_id));
}
InternalAsyncMessage::CertificateInteractionRequest {
status,
@@ -401,7 +496,7 @@ fn update_sync_client(
action_sender,
} => {
certificate_interaction_events.send(CertInteractionEvent {
- connection_id,
+ connection_id: *connection_id,
status,
info,
action_sender: Mutex::new(Some(action_sender)),
@@ -409,13 +504,13 @@ fn update_sync_client(
}
InternalAsyncMessage::CertificateTrustUpdate(info) => {
cert_trust_update_events.send(CertTrustUpdateEvent {
- connection_id,
+ connection_id: *connection_id,
cert_info: info,
});
}
InternalAsyncMessage::CertificateConnectionAbort { status, cert_info } => {
cert_connection_abort_events.send(CertConnectionAbortEvent {
- connection_id,
+ connection_id: *connection_id,
status,
cert_info,
});
@@ -427,7 +522,10 @@ fn update_sync_client(
fn create_client(mut commands: Commands, runtime: Res<AsyncRuntime>) {
commands.insert_resource(Client {
+ connections: HashMap::new(),
runtime: runtime.handle().clone(),
+ last_gen_id: 0,
+ default_connection_id: None,
});
}