aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/client.rs255
-rw-r--r--src/client/certificate.rs4
2 files changed, 134 insertions, 125 deletions
diff --git a/src/client.rs b/src/client.rs
index 201a795..553f7ae 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -9,10 +9,9 @@ use bytes::Bytes;
use futures::sink::SinkExt;
use futures_util::StreamExt;
use quinn::{ClientConfig, Endpoint};
-use rustls::Certificate;
use serde::Deserialize;
use tokio::{
- runtime::Runtime,
+ runtime::{self, Runtime},
sync::{
broadcast,
mpsc::{
@@ -38,21 +37,23 @@ pub mod certificate;
pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100;
pub const DEFAULT_KNOWN_HOSTS_FILE: &str = "quinnet/known_hosts";
+pub type ConnectionId = Entity;
+
/// Connection event raised when the client just connected to the server. Raised in the CoreStage::PreUpdate stage.
-pub struct ConnectionEvent;
+pub struct ConnectionEvent(ConnectionId);
/// ConnectionLost event raised when the client is considered disconnected from the server. Raised in the CoreStage::PreUpdate stage.
-pub struct ConnectionLostEvent;
+pub struct ConnectionLostEvent(ConnectionId);
/// Configuration of the client, used when connecting to a server
#[derive(Debug, Deserialize, Clone)]
-pub struct ClientConfigurationData {
+pub struct ConnectionConfiguration {
server_host: String,
server_port: u16,
local_bind_host: String,
local_bind_port: u16,
}
-impl ClientConfigurationData {
+impl ConnectionConfiguration {
/// Creates a new ClientConfigurationData
///
/// # Arguments
@@ -89,14 +90,14 @@ impl ClientConfigurationData {
/// Current state of the client driver
#[derive(Debug, PartialEq, Eq)]
-enum ClientState {
+enum ConnectionState {
Disconnected,
Connected,
}
#[derive(Debug)]
pub(crate) enum InternalAsyncMessage {
- Connected(Option<Vec<Certificate>>),
+ Connected,
LostConnection,
CertificateActionRequest {
status: CertVerificationStatus,
@@ -108,41 +109,34 @@ pub(crate) enum InternalAsyncMessage {
},
}
-#[derive(Debug, Clone)]
-pub(crate) enum InternalSyncMessage {
- Connect {
- config: ClientConfigurationData,
- cert_mode: CertificateVerificationMode,
- },
+#[derive(Debug)]
+pub(crate) struct ConnectionSpawnConfig {
+ connection_config: ConnectionConfiguration,
+ cert_mode: CertificateVerificationMode,
+ to_sync_client: mpsc::Sender<InternalAsyncMessage>,
+ close_sender: tokio::sync::broadcast::Sender<()>,
+ close_receiver: tokio::sync::broadcast::Receiver<()>,
+ to_server_receiver: mpsc::Receiver<Bytes>,
+ from_server_sender: mpsc::Sender<Bytes>,
}
-pub struct Client {
- state: ClientState,
+// #[derive(Debug)]
+// pub(crate) enum InternalSyncMessage {
+// // SpawnConnection(ConnectionSpawnConfig),
+// }
+
+#[derive(Component)]
+pub struct Connection {
+ state: ConnectionState,
// TODO Perf: multiple channels
sender: mpsc::Sender<Bytes>,
receiver: mpsc::Receiver<Bytes>,
close_sender: broadcast::Sender<()>,
-
pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>,
- pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>,
+ // pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>,
}
-impl Client {
- /// Connect to a server with the given [ClientConfigurationData] and [CertificateVerificationMode]
- pub fn connect(
- &self,
- config: ClientConfigurationData,
- cert_mode: CertificateVerificationMode,
- ) -> Result<(), QuinnetError> {
- match self
- .internal_sender
- .try_send(InternalSyncMessage::Connect { config, cert_mode })
- {
- Ok(_) => Ok(()),
- Err(_) => Err(QuinnetError::FullQueue),
- }
- }
-
+impl Connection {
/// Disconnect the client. This does not send any message to the server, and simply closes all the connection tasks locally.
pub fn disconnect(&mut self) -> Result<(), QuinnetError> {
if self.is_connected() {
@@ -150,7 +144,7 @@ impl Client {
return Err(QuinnetError::ChannelClosed);
}
}
- self.state = ClientState::Disconnected;
+ self.state = ConnectionState::Disconnected;
Ok(())
}
@@ -194,7 +188,63 @@ impl Client {
}
pub fn is_connected(&self) -> bool {
- return self.state == ClientState::Connected;
+ return self.state == ConnectionState::Connected;
+ }
+}
+
+pub struct Client {
+ // connections: HashMap<ConnectionId, Connection>,
+ runtime: runtime::Handle,
+}
+
+impl Client {
+ /// Connect to a server with the given [ClientConfigurationData] and [CertificateVerificationMode]
+ pub fn spawn_connection(
+ &self,
+ commands: &mut Commands,
+ config: ConnectionConfiguration,
+ cert_mode: CertificateVerificationMode,
+ ) -> ConnectionId {
+ let (from_server_sender, from_server_receiver) =
+ mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
+ let (to_server_sender, to_server_receiver) =
+ mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
+
+ let (to_sync_client, from_async_client) =
+ mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
+
+ // Create a close channel for this connection
+ let (close_sender, close_receiver): (
+ tokio::sync::broadcast::Sender<()>,
+ tokio::sync::broadcast::Receiver<()>,
+ ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
+
+ let connection = commands
+ .spawn()
+ .insert(Connection {
+ state: ConnectionState::Disconnected,
+ sender: to_server_sender,
+ receiver: from_server_receiver,
+ close_sender: close_sender.clone(),
+ internal_receiver: from_async_client,
+ // internal_sender: to_async_client,
+ })
+ .id();
+
+ // Async connection
+ self.runtime.spawn(async move {
+ connection_task(ConnectionSpawnConfig {
+ connection_config: config,
+ cert_mode,
+ to_sync_client,
+ close_sender,
+ close_receiver,
+ to_server_receiver,
+ from_server_sender,
+ })
+ .await
+ });
+ connection
}
}
@@ -230,15 +280,8 @@ fn configure_client(
}
}
-async fn connection_task(
- config: ClientConfigurationData,
- cert_mode: CertificateVerificationMode,
- to_sync_client: mpsc::Sender<InternalAsyncMessage>,
- close_sender: tokio::sync::broadcast::Sender<()>,
- mut close_receiver: tokio::sync::broadcast::Receiver<()>,
- mut to_server_receiver: mpsc::Receiver<Bytes>,
- from_server_sender: mpsc::Sender<Bytes>,
-) {
+async fn connection_task(mut spawn_config: ConnectionSpawnConfig) {
+ let config = spawn_config.connection_config;
let server_adr_str = format!("{}:{}", config.server_host, config.server_port);
let srv_host = config.server_host.clone();
let local_bind_adr = format!("{}:{}", config.local_bind_host, config.local_bind_port);
@@ -249,8 +292,8 @@ async fn connection_task(
.parse()
.expect("Failed to parse server address");
- let client_cfg =
- configure_client(cert_mode, to_sync_client.clone()).expect("Failed to configure client");
+ let client_cfg = configure_client(spawn_config.cert_mode, spawn_config.to_sync_client.clone())
+ .expect("Failed to configure client");
let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap())
.expect("Failed to create client endpoint");
@@ -268,8 +311,9 @@ async fn connection_task(
new_connection.connection.remote_address()
);
- to_sync_client
- .send(InternalAsyncMessage::Connected(None))
+ spawn_config
+ .to_sync_client
+ .send(InternalAsyncMessage::Connected)
.await
.expect("Failed to signal connection to sync client");
@@ -280,21 +324,21 @@ async fn connection_task(
.expect("Failed to open send stream");
let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new());
- let close_sender_clone = close_sender.clone();
+ let close_sender_clone = spawn_config.close_sender.clone();
let _network_sends = tokio::spawn(async move {
tokio::select! {
- _ = close_receiver.recv() => {
+ _ = spawn_config.close_receiver.recv() => {
trace!("Unidirectional send Stream forced to disconnected")
}
_ = async {
- while let Some(msg_bytes) = to_server_receiver.recv().await {
+ while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await {
if let Err(err) = frame_send.send(msg_bytes).await {
error!("Error while sending, {}", err); // TODO Clean: error handling
error!("Client seems disconnected, closing resources");
if let Err(_) = close_sender_clone.send(()) {
error!("Failed to close all client streams & resources")
}
- to_sync_client.send(
+ spawn_config.to_sync_client.send(
InternalAsyncMessage::LostConnection)
.await
.expect("Failed to signal connection lost to sync client");
@@ -307,7 +351,7 @@ async fn connection_task(
});
let mut uni_receivers: JoinSet<()> = JoinSet::new();
- let mut close_receiver = close_sender.subscribe();
+ let mut close_receiver = spawn_config.close_sender.subscribe();
let _network_reads = tokio::spawn(async move {
tokio::select! {
_ = close_receiver.recv() => {
@@ -316,7 +360,7 @@ async fn connection_task(
_ = async {
while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
- let from_server_sender = from_server_sender.clone();
+ let from_server_sender = spawn_config.from_server_sender.clone();
uni_receivers.spawn(async move {
while let Some(Ok(msg_bytes)) = frame_recv.next().await {
@@ -335,88 +379,51 @@ async fn connection_task(
}
}
-fn start_async_client(mut commands: Commands, runtime: Res<Runtime>) {
- let (from_server_sender, from_server_receiver) =
- mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
- let (to_server_sender, to_server_receiver) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
-
- let (to_sync_client, from_async_client) =
- mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
- let (to_async_client, mut from_sync_client) =
- mpsc::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
-
- // Create a close channel for this connection
- let (close_sender, close_receiver): (
- tokio::sync::broadcast::Sender<()>,
- tokio::sync::broadcast::Receiver<()>,
- ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
-
+fn create_client(mut commands: Commands, runtime: Res<Runtime>) {
commands.insert_resource(Client {
- state: ClientState::Disconnected,
- sender: to_server_sender,
- receiver: from_server_receiver,
- close_sender: close_sender.clone(),
- internal_receiver: from_async_client,
- internal_sender: to_async_client.clone(),
- });
-
- // Async client
- runtime.spawn(async move {
- // Wait for a connection signal before starting client
- if let Some(message) = from_sync_client.recv().await {
- match message {
- InternalSyncMessage::Connect { config, cert_mode } => {
- connection_task(
- config,
- cert_mode,
- to_sync_client,
- close_sender,
- close_receiver,
- to_server_receiver,
- from_server_sender,
- )
- .await;
- }
- }
- }
+ runtime: runtime.handle().clone(),
});
}
// Receive messages from the async client tasks and update the sync client.
fn update_sync_client(
- mut client: ResMut<Client>,
mut connection_events: EventWriter<ConnectionEvent>,
mut connection_lost_events: EventWriter<ConnectionLostEvent>,
mut certificate_interaction_events: EventWriter<CertificateInteractionEvent>,
mut certificate_update_events: EventWriter<CertificateUpdateEvent>,
+ mut connections: Query<(&mut Connection, ConnectionId)>,
) {
- while let Ok(message) = client.internal_receiver.try_recv() {
- match message {
- InternalAsyncMessage::Connected(_) => {
- client.state = ClientState::Connected;
- connection_events.send(ConnectionEvent);
- }
- InternalAsyncMessage::LostConnection => {
- client.state = ClientState::Disconnected;
- connection_lost_events.send(ConnectionLostEvent);
- }
- InternalAsyncMessage::CertificateActionRequest {
- status,
- action_sender,
- } => {
- certificate_interaction_events.send(CertificateInteractionEvent {
+ for (mut connection, connection_id) in connections.iter_mut() {
+ while let Ok(message) = connection.internal_receiver.try_recv() {
+ match message {
+ InternalAsyncMessage::Connected => {
+ connection.state = ConnectionState::Connected;
+ connection_events.send(ConnectionEvent(connection_id));
+ }
+ InternalAsyncMessage::LostConnection => {
+ connection.state = ConnectionState::Disconnected;
+ connection_lost_events.send(ConnectionLostEvent(connection_id));
+ }
+ InternalAsyncMessage::CertificateActionRequest {
status,
- action_sender: Mutex::new(Some(action_sender)),
- });
- }
- InternalAsyncMessage::TrustedCertificateUpdate {
- server_name,
- fingerprint,
- } => {
- certificate_update_events.send(CertificateUpdateEvent {
+ action_sender,
+ } => {
+ certificate_interaction_events.send(CertificateInteractionEvent {
+ connection_id,
+ status,
+ action_sender: Mutex::new(Some(action_sender)),
+ });
+ }
+ InternalAsyncMessage::TrustedCertificateUpdate {
server_name,
fingerprint,
- });
+ } => {
+ certificate_update_events.send(CertificateUpdateEvent {
+ connection_id,
+ server_name,
+ fingerprint,
+ });
+ }
}
}
}
@@ -437,7 +444,7 @@ impl Plugin for QuinnetClientPlugin {
.add_event::<CertificateInteractionEvent>()
.add_event::<CertificateUpdateEvent>()
// StartupStage::PreStartup so that resources created in commands are available to default startup_systems
- .add_startup_system_to_stage(StartupStage::PreStartup, start_async_client)
+ .add_startup_system_to_stage(StartupStage::PreStartup, create_client)
.add_system(update_sync_client);
if app.world.get_resource_mut::<Runtime>().is_none() {
diff --git a/src/client/certificate.rs b/src/client/certificate.rs
index 44838ff..f3ddde2 100644
--- a/src/client/certificate.rs
+++ b/src/client/certificate.rs
@@ -15,13 +15,14 @@ use tokio::sync::{mpsc, oneshot};
use crate::QuinnetError;
-use super::{InternalAsyncMessage, DEFAULT_KNOWN_HOSTS_FILE};
+use super::{ConnectionId, InternalAsyncMessage, DEFAULT_KNOWN_HOSTS_FILE};
pub const DEFAULT_CERT_VERIFIER_BEHAVIOUR: CertVerifierBehaviour =
CertVerifierBehaviour::ImmediateAction(CertVerifierAction::AbortConnection);
/// Event raised when a user/app interaction is needed for the server's certificate validation
pub struct CertificateInteractionEvent {
+ pub connection_id: ConnectionId,
/// The current status of the verification
pub status: CertVerificationStatus,
/// Mutex for interior mutability
@@ -39,6 +40,7 @@ impl CertificateInteractionEvent {
/// Event raised when a new certificate is trusted
pub struct CertificateUpdateEvent {
+ pub connection_id: ConnectionId,
/// Identifies the server name
pub server_name: ServerName,
/// Fingerprint of the server's certificate