aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenauxg <19689618+Henauxg@users.noreply.github.com>2022-10-26 18:00:11 +0200
committerHenauxg <19689618+Henauxg@users.noreply.github.com>2022-10-26 18:00:11 +0200
commite85b70b4cd38d27c810385e1a79fdf98016f5f88 (patch)
treeedbceac2ba8b25994772af4f3968b6c0a8a0ee82
parent99838c7e719057aa42b68acad2290eb57bebe78e (diff)
[server] Add CertificateRetrievalMode and start() method to start listening
-rw-r--r--src/server.rs434
1 files changed, 251 insertions, 183 deletions
diff --git a/src/server.rs b/src/server.rs
index aeb91d1..c6de319 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,42 +1,38 @@
-use std::{
- collections::{HashMap},
- error::Error,
- net::SocketAddr,
- sync::{Arc},
- time::Duration,
-};
+use std::{collections::HashMap, error::Error, net::SocketAddr, sync::Arc, time::Duration};
use bevy::prelude::*;
use bytes::Bytes;
-use futures::{
- sink::SinkExt,
-};
+use futures::sink::SinkExt;
use futures_util::StreamExt;
use quinn::{Endpoint, NewConnection, ServerConfig};
+use rustls::Certificate;
use serde::Deserialize;
use tokio::{
runtime::Runtime,
sync::{
broadcast::{self},
- mpsc::{
- self,
- error::{TryRecvError},
- },
+ mpsc::{self, error::TryRecvError},
},
- task::{ JoinSet},
+ task::JoinSet,
};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
-use crate::{ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_MESSAGE_QUEUE_SIZE, DEFAULT_KILL_MESSAGE_QUEUE_SIZE};
+use crate::{
+ ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_KILL_MESSAGE_QUEUE_SIZE,
+ DEFAULT_MESSAGE_QUEUE_SIZE,
+};
pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 100;
-
/// Connection event raised when a client just connected to the server. Raised in the CoreStage::PreUpdate stage.
-pub struct ConnectionEvent { pub id: ClientId }
+pub struct ConnectionEvent {
+ pub id: ClientId,
+}
/// ConnectionLost event raised when a client is considered disconnected from the server. Raised in the CoreStage::PreUpdate stage.
-pub struct ConnectionLostEvent { pub id: ClientId }
+pub struct ConnectionLostEvent {
+ pub id: ClientId,
+}
#[derive(Debug)]
pub(crate) enum InternalAsyncMessage {
@@ -46,10 +42,14 @@ pub(crate) enum InternalAsyncMessage {
#[derive(Debug, Clone)]
pub(crate) enum InternalSyncMessage {
+ StartListening {
+ config: ServerConfigurationData,
+ cert_mode: CertificateRetrievalMode,
+ },
ClientConnectedAck(ClientId),
}
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
pub struct ServerConfigurationData {
host: String,
port: u16,
@@ -57,11 +57,7 @@ pub struct ServerConfigurationData {
}
impl ServerConfigurationData {
- pub fn new(
- host: String,
- port: u16,
- local_bind_host: String,
- ) -> Self {
+ pub fn new(host: String, port: u16, local_bind_host: String) -> Self {
Self {
host,
port,
@@ -70,7 +66,6 @@ impl ServerConfigurationData {
}
}
-
#[derive(Debug)]
pub struct ClientPayload {
client_id: ClientId,
@@ -79,7 +74,7 @@ pub struct ClientPayload {
#[derive(Debug)]
pub(crate) struct ClientConnection {
- client_id: ClientId,
+ client_id: ClientId,
sender: mpsc::Sender<Bytes>,
close_sender: broadcast::Sender<()>,
}
@@ -92,15 +87,45 @@ pub struct Server {
pub(crate) internal_sender: broadcast::Sender<InternalSyncMessage>,
}
+/// How the server should retrieve its certificate.
+#[derive(Debug, Clone)]
+pub enum CertificateRetrievalMode {
+ Provided(Certificate),
+ GenerateSelfSigned,
+ LoadCertFromFile(String),
+ LoadCertFromFileOrGenerateSelfSigned(String),
+}
+
impl Server {
+ /// Run the server with the given [ServerConfigurationData] and [CertificateRetrievalMode]
+ pub fn start(
+ &self,
+ config: ServerConfigurationData,
+ cert_mode: CertificateRetrievalMode,
+ ) -> Result<(), QuinnetError> {
+ match self
+ .internal_sender
+ .send(InternalSyncMessage::StartListening { config, cert_mode })
+ {
+ Ok(_) => Ok(()),
+ Err(_) => Err(QuinnetError::FullQueue),
+ }
+ }
+
pub fn disconnect_client(&mut self, client_id: ClientId) {
match self.clients.remove(&client_id) {
Some(client_connection) => {
if let Err(_) = client_connection.close_sender.send(()) {
- error!("Failed to close client streams & resources while disconnecting client {}", client_id)
+ error!(
+ "Failed to close client streams & resources while disconnecting client {}",
+ client_id
+ )
}
- },
- None => error!("Failed to disconnect client {}, client not found", client_id),
+ }
+ None => error!(
+ "Failed to disconnect client {}, client not found",
+ client_id
+ ),
}
}
@@ -127,7 +152,7 @@ impl Server {
}
}
- pub fn send_group_message<'a, I: Iterator<Item=&'a ClientId>, T: serde::Serialize>(
+ pub fn send_group_message<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>(
&mut self,
client_ids: I,
message: T,
@@ -139,7 +164,7 @@ impl Server {
self.send_payload(*id, payload.clone());
}
Ok(())
- },
+ }
Err(_) => Err(QuinnetError::Serialization),
}
}
@@ -157,7 +182,10 @@ impl Server {
pub fn broadcast_payload<T: Into<Bytes> + Clone>(&mut self, payload: T) {
// TODO Fix: Error handling
for (_, client_connection) in self.clients.iter() {
- client_connection.sender.try_send(payload.clone().into()).unwrap();
+ client_connection
+ .sender
+ .try_send(payload.clone().into())
+ .unwrap();
}
}
@@ -183,25 +211,36 @@ impl Server {
}
/// Returns default server configuration along with its certificate.
-#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527
-fn configure_server(server_host: &String) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
- let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap();
- let cert_der = cert.serialize_der().unwrap();
- let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
- let cert_chain = vec![rustls::Certificate(cert_der.clone())];
-
- let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
- Arc::get_mut(&mut server_config.transport)
- .unwrap()
- .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
-
- Ok((server_config, cert_der))
+fn configure_server(
+ server_host: &String,
+ cert_mode: CertificateRetrievalMode,
+) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
+ match cert_mode {
+ CertificateRetrievalMode::Provided(_cert) => todo!(),
+ CertificateRetrievalMode::GenerateSelfSigned => {
+ let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap();
+ let cert_der = cert.serialize_der().unwrap();
+ let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
+ let cert_chain = vec![rustls::Certificate(cert_der.clone())];
+
+ let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
+ Arc::get_mut(&mut server_config.transport)
+ .unwrap()
+ .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
+
+ Ok((server_config, cert_der))
+ }
+ CertificateRetrievalMode::LoadCertFromFile(_) => todo!(),
+ CertificateRetrievalMode::LoadCertFromFileOrGenerateSelfSigned(_) => todo!(),
+ }
}
-fn start_server(
- mut commands: Commands,
- runtime: Res<Runtime>,
- config: Res<ServerConfigurationData>,
+async fn connections_listening_task(
+ config: ServerConfigurationData,
+ cert_mode: CertificateRetrievalMode,
+ to_sync_server: mpsc::Sender<InternalAsyncMessage>,
+ mut from_sync_server: broadcast::Receiver<InternalSyncMessage>,
+ from_clients_sender: mpsc::Sender<ClientPayload>,
) {
// TODO Fix: handle unwraps
let server_adr_str = format!("{}:{}", config.local_bind_host, config.port);
@@ -211,153 +250,179 @@ fn start_server(
.parse()
.expect("Failed to parse server address");
- // TODO Security: Server certificate
let (server_config, _server_cert) =
- configure_server(&config.host).expect("Failed to configure server");
+ configure_server(&config.host, cert_mode).expect("Failed to configure server");
+
+ let mut client_gen_id: ClientId = 0;
+ let mut client_id_mappings = HashMap::new();
+
+ let (_endpoint, mut incoming) =
+ Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint");
+
+ // Start iterating over incoming connections/clients.
+ while let Some(conn) = incoming.next().await {
+ let mut new_connection: NewConnection =
+ conn.await.expect("Failed to handle incoming connection");
+ let connection = new_connection.connection;
+
+ // Attribute an id to this client
+ client_gen_id += 1; // TODO Fix: Better id generation/check
+ let client_id = client_gen_id;
+ client_id_mappings.insert(connection.stable_id(), client_id);
+
+ info!(
+ "New connection from {}, client_id: {}, stable_id : {}",
+ connection.remote_address(),
+ client_id,
+ connection.stable_id()
+ );
+
+ // Create a close channel for this client
+ let (close_sender, mut close_receiver): (
+ tokio::sync::broadcast::Sender<()>,
+ tokio::sync::broadcast::Receiver<()>,
+ ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
+
+ // Create an ordered reliable send channel for this client
+ let (to_client_sender, mut to_client_receiver) =
+ mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
+ let send_stream = connection.open_uni().await.expect(
+ format!(
+ "Failed to open unidirectional send stream for client: {}",
+ client_id
+ )
+ .as_str(),
+ );
+
+ let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
+
+ let to_sync_server_clone = to_sync_server.clone();
+ let close_sender_clone = close_sender.clone();
+ let _network_broadcaster = tokio::spawn(async move {
+ tokio::select! {
+ _ = close_receiver.recv() => {
+ trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
+ }
+ _ = async {
+ while let Some(msg_bytes) = to_client_receiver.recv().await {
+ // TODO Perf: Batch frames for a send_all
+ // TODO Clean: Error handling
+ if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
+ error!("Error while sending to client {}: {}", client_id, err);
+ error!("Client {} seems disconnected, closing resources", client_id);
+ if let Err(_) = close_sender_clone.send(()) {
+ error!("Failed to close all client streams & resources for client {}", client_id)
+ }
+ to_sync_server_clone.send(
+ InternalAsyncMessage::ClientLostConnection(client_id))
+ .await
+ .expect("Failed to signal connection lost to sync server");
+ };
+ }
+ } => {}
+ }
+ });
+
+ // Signal the sync server of this new connection
+ // let mut sync_msg_receiver = to_async_server.subscribe();
+ to_sync_server
+ .send(InternalAsyncMessage::ClientConnected(ClientConnection {
+ client_id: client_id,
+ sender: to_client_sender,
+ close_sender: close_sender.clone(),
+ }))
+ .await
+ .expect("Failed to signal connection to sync client");
+
+ // Wait for the sync server to acknowledge the connection before spawning reception tasks.
+ while let Ok(message) = from_sync_server.recv().await {
+ match message {
+ InternalSyncMessage::ClientConnectedAck(id) => {
+ if id == client_id {
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Spawn a task to listen for stream opened from this client
+ let from_client_sender_clone = from_clients_sender.clone();
+ let mut uni_receivers: JoinSet<()> = JoinSet::new();
+ let mut close_receiver = close_sender.subscribe();
+ let _client_receiver = tokio::spawn(async move {
+ tokio::select! {
+ _ = close_receiver.recv() => {
+ trace!("New Stream listener forced to disconnected for client: {}", client_id)
+ }
+ _ = async {
+ // For each new stream opened by the client
+ while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
+ let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
+
+ // Spawn a task to receive data on this stream.
+ let from_client_sender = from_client_sender_clone.clone();
+ uni_receivers.spawn(async move {
+ while let Some(Ok(msg_bytes)) = frame_recv.next().await {
+ from_client_sender
+ .send(ClientPayload {
+ client_id: client_id,
+ msg: msg_bytes.into(),
+ })
+ .await
+ .unwrap();// TODO Fix: error event
+ }
+ trace!("Unidirectional stream receiver ended for client: {}", client_id)
+ });
+ }
+ } => {
+ trace!("New Stream listener ended for client: {}", client_id)
+ }
+ }
+ uni_receivers.shutdown().await;
+ trace!(
+ "All unidirectional stream receivers cleaned for client: {}",
+ client_id
+ )
+ });
+ }
+}
+
+fn start_async_server(mut commands: Commands, runtime: Res<Runtime>) {
// TODO Clean: Configure size
let (from_clients_sender, from_clients_receiver) =
mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
let (to_sync_server, from_async_server) =
mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
- let (to_async_server, _) =
+ let (to_async_server, mut from_sync_server) =
broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
commands.insert_resource(Server {
clients: HashMap::new(),
receiver: from_clients_receiver,
internal_receiver: from_async_server,
- internal_sender: to_async_server.clone()
+ internal_sender: to_async_server.clone(),
});
-
- // Create server task
+ // Create async server task
runtime.spawn(async move {
- let mut client_gen_id: ClientId = 0;
- let mut client_id_mappings = HashMap::new();
-
- let (_endpoint, mut incoming) =
- Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint");
-
- // Start iterating over incoming connections/clients.
- while let Some(conn) = incoming.next().await {
- let mut new_connection: NewConnection =
- conn.await.expect("Failed to handle incoming connection");
- let connection = new_connection.connection;
-
- // Attribute an id to this client
- client_gen_id += 1; // TODO Fix: Better id generation/check
- let client_id = client_gen_id;
- client_id_mappings.insert(connection.stable_id(), client_id);
-
- info!(
- "New connection from {}, client_id: {}, stable_id : {}",
- connection.remote_address(),
- client_id,
- connection.stable_id()
- );
-
- // Create a close channel for this client
- let (close_sender, mut close_receiver): (
- tokio::sync::broadcast::Sender<()>,
- tokio::sync::broadcast::Receiver<()>,
- ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
-
- // Create an ordered reliable send channel for this client
- let (to_client_sender, mut to_client_receiver) =
- mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
-
- let send_stream = connection
- .open_uni()
- .await
- .expect( format!("Failed to open unidirectional send stream for client: {}", client_id).as_str());
-
- let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
-
- let to_sync_server_clone = to_sync_server.clone();
- let close_sender_clone = close_sender.clone();
- let _network_broadcaster = tokio::spawn(async move {
- tokio::select! {
- _ = close_receiver.recv() => {
- trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
- }
- _ = async {
- while let Some(msg_bytes) = to_client_receiver.recv().await {
- // TODO Perf: Batch frames for a send_all
- // TODO Clean: Error handling
- if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
- error!("Error while sending to client {}: {}", client_id, err);
- error!("Client {} seems disconnected, closing resources", client_id);
- if let Err(_) = close_sender_clone.send(()) {
- error!("Failed to close all client streams & resources for client {}", client_id)
- }
- to_sync_server_clone.send(
- InternalAsyncMessage::ClientLostConnection(client_id))
- .await
- .expect("Failed to signal connection lost to sync server");
- };
- }
- } => {}
- }
- });
-
- // Signal the sync server of this new connection
- let mut sync_msg_receiver = to_async_server.subscribe();
- to_sync_server.send(
- InternalAsyncMessage::ClientConnected(ClientConnection {
- client_id: client_id,
- sender: to_client_sender,
- close_sender: close_sender.clone()
- }))
- .await
- .expect("Failed to signal connection to sync client");
-
- // Wait for the sync server to acknowledge the connection before spawning reception tasks.
- while let Ok(message) = sync_msg_receiver.recv().await {
- match message {
- InternalSyncMessage::ClientConnectedAck(id) =>{
- if id == client_id {break;}
- },
+ // Wait for the sync server to to signal us to start listening
+ while let Ok(message) = from_sync_server.recv().await {
+ match message {
+ InternalSyncMessage::StartListening { config, cert_mode } => {
+ connections_listening_task(
+ config,
+ cert_mode,
+ to_sync_server.clone(),
+ to_async_server.subscribe(),
+ from_clients_sender.clone(),
+ )
+ .await;
}
+ _ => {}
}
-
- // Spawn a task to listen for stream opened from this client
- let from_client_sender_clone = from_clients_sender.clone();
- let mut uni_receivers:JoinSet<()> = JoinSet::new();
- let mut close_receiver = close_sender.subscribe();
- let _client_receiver = tokio::spawn(async move {
- tokio::select! {
- _ = close_receiver.recv() => {
- trace!("New Stream listener forced to disconnected for client: {}", client_id)
- }
- _ = async {
- // For each new stream opened by the client
- while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
- let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
-
- // Spawn a task to receive data on this stream.
- let from_client_sender = from_client_sender_clone.clone();
- uni_receivers.spawn(async move {
- while let Some(Ok(msg_bytes)) = frame_recv.next().await {
- from_client_sender
- .send(ClientPayload {
- client_id: client_id,
- msg: msg_bytes.into(),
- })
- .await
- .unwrap();// TODO Fix: error event
- }
- trace!("Unidirectional stream receiver ended for client: {}", client_id)
- });
- }
- } => {
- trace!("New Stream listener ended for client: {}", client_id)
- }
- }
- uni_receivers.shutdown().await;
- trace!("All unidirectional stream receivers cleaned for client: {}", client_id)
- });
}
});
}
@@ -366,20 +431,23 @@ fn start_server(
fn update_sync_server(
mut server: ResMut<Server>,
mut connection_events: EventWriter<ConnectionEvent>,
- mut connection_lost_events: EventWriter<ConnectionLostEvent>
+ mut connection_lost_events: EventWriter<ConnectionLostEvent>,
) {
while let Ok(message) = server.internal_receiver.try_recv() {
match message {
InternalAsyncMessage::ClientConnected(connection) => {
let id = connection.client_id;
server.clients.insert(id, connection);
- server.internal_sender.send(InternalSyncMessage::ClientConnectedAck(id)).unwrap();
+ server
+ .internal_sender
+ .send(InternalSyncMessage::ClientConnectedAck(id))
+ .unwrap();
connection_events.send(ConnectionEvent { id: id });
}
InternalAsyncMessage::ClientLostConnection(client_id) => {
server.clients.remove(&client_id);
connection_lost_events.send(ConnectionLostEvent { id: client_id });
- },
+ }
}
}
}
@@ -402,7 +470,7 @@ impl Plugin for QuinnetServerPlugin {
)
.add_event::<ConnectionEvent>()
.add_event::<ConnectionLostEvent>()
- .add_startup_system_to_stage(StartupStage::PreStartup, start_server)
- .add_system_to_stage(CoreStage::PreUpdate,update_sync_server);
+ .add_startup_system_to_stage(StartupStage::PreStartup, start_async_server)
+ .add_system_to_stage(CoreStage::PreUpdate, update_sync_server);
}
}