diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 269 | ||||
-rw-r--r-- | src/lib.rs | 30 | ||||
-rw-r--r-- | src/server.rs | 276 |
3 files changed, 575 insertions, 0 deletions
diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..8d4ea8e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,269 @@ +use std::{net::SocketAddr, sync::Arc}; + +use bevy::prelude::*; +use bytes::Bytes; +use futures::sink::SinkExt; +use futures_util::StreamExt; +use quinn::{ClientConfig, Endpoint}; +use serde::Deserialize; +use tokio::{ + runtime::Runtime, + sync::mpsc::{ + self, + error::{TryRecvError, TrySendError}, + }, +}; +use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; + +use crate::{QuinnetError, DEFAULT_MESSAGE_QUEUE_SIZE}; + +pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 10; + +#[derive(Deserialize)] +pub struct ClientConfigurationData { + server_host: String, + server_port: u16, + local_bind_host: String, + local_bind_port: u16, +} + +#[derive(Debug, PartialEq, Eq)] +enum ClientState { + Connecting, + Connected, +} + +#[derive(Debug)] +pub(crate) enum InternalAsyncMessage { + Connected, +} + +#[derive(Debug)] +pub(crate) enum InternalSyncMessage { + Connect, +} + +pub struct Client { + state: ClientState, + // TODO Perf: multiple channels + sender: mpsc::Sender<Bytes>, + receiver: mpsc::Receiver<Bytes>, + + pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>, + pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>, +} + +impl Client { + pub fn connect(&self) -> Result<(), QuinnetError> { + match self.internal_sender.try_send(InternalSyncMessage::Connect) { + Ok(_) => Ok(()), + Err(_) => Err(QuinnetError::FullQueue), + } + } + + pub fn receive_message<T: serde::de::DeserializeOwned>( + &mut self, + ) -> Result<Option<T>, QuinnetError> { + match self.receive_payload()? { + Some(payload) => match bincode::deserialize(&payload) { + Ok(msg) => Ok(Some(msg)), + Err(_) => Err(QuinnetError::Deserialization), + }, + None => Ok(None), + } + } + + pub fn send_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> { + match bincode::serialize(&message) { + Ok(payload) => self.send_payload(payload), + Err(_) => Err(QuinnetError::Serialization), + } + } + + pub fn send_payload<T: Into<Bytes>>(&self, payload: T) -> Result<(), QuinnetError> { + match self.sender.try_send(payload.into()) { + Ok(_) => Ok(()), + Err(err) => match err { + TrySendError::Full(_) => Err(QuinnetError::FullQueue), + TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed), + }, + } + } + + pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> { + match self.receiver.try_recv() { + Ok(msg_payload) => Ok(Some(msg_payload)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed), + }, + } + } + + pub fn is_connected(&self) -> bool { + return self.state == ClientState::Connected; + } +} + +// Implementation of `ServerCertVerifier` that verifies everything as trustworthy. +struct SkipServerVerification; + +impl SkipServerVerification { + fn new() -> Arc<Self> { + Arc::new(Self) + } +} + +impl rustls::client::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator<Item = &[u8]>, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result<rustls::client::ServerCertVerified, rustls::Error> { + Ok(rustls::client::ServerCertVerified::assertion()) + } +} + +fn configure_client() -> ClientConfig { + let crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(); + + ClientConfig::new(Arc::new(crypto)) +} + +fn initialize_client( + mut commands: Commands, + runtime: Res<Runtime>, + config: Res<ClientConfigurationData>, +) { + let server_adr_str = format!("{}:{}", config.server_host, config.server_port); + let srv_host = config.server_host.clone(); + let local_bind_adr = format!("{}:{}", config.local_bind_host, config.local_bind_port); + + info!("Trying to connect to server on: {} ...", server_adr_str); + + let server_addr: SocketAddr = server_adr_str + .parse() + .expect("Failed to parse server address"); + + let client_cfg = configure_client(); + + let (from_server_sender, from_server_receiver) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (to_server_sender, mut to_server_receiver) = + mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + + let (to_sync_client, from_async_client) = + mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_async_client, mut from_sync_client) = + mpsc::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + + runtime.spawn(async move { + // Wait for a connection signal before starting client + if let Some(message) = from_sync_client.recv().await { + match message { + InternalSyncMessage::Connect => info!("Client requested to connect"), + } + } else { + warn!("Client closed before requesting a connection"); + return; + } + + let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap()) + .expect("Failed to create client endpoint"); + endpoint.set_default_client_config(client_cfg); + + let mut new_connection = endpoint + .connect(server_addr, &srv_host) // TODO Clean: error handling + .expect("Failed to connect: configuration error") + .await + .expect("Failed to connect"); + info!( + "Connected to {}", + new_connection.connection.remote_address() + ); + + to_sync_client + .send(InternalAsyncMessage::Connected) + .await + .expect("Failed to signal connection to sync client"); + + let send = new_connection + .connection + .open_uni() + .await + .expect("Failed to open send stream"); + let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new()); + + let network_sends = async move { + loop { + if let Some(msg_bytes) = to_server_receiver.recv().await { + if let Err(err) = frame_send.send(msg_bytes).await { + error!("Error sending {}", err) // TODO Fix: error event + } + } + } + }; + + let network_reads = async move { + while let Some(Ok(recv)) = new_connection.uni_streams.next().await { + let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); + let from_server_sender = from_server_sender.clone(); + tokio::spawn(async move { + while let Some(Ok(msg_bytes)) = frame_recv.next().await { + from_server_sender.send(msg_bytes.into()).await.unwrap(); + // TODO Fix: error event + } + }); + } + }; + + tokio::join!(network_sends, network_reads); + }); + + commands.insert_resource(Client { + state: ClientState::Connecting, + sender: to_server_sender, + receiver: from_server_receiver, + internal_receiver: from_async_client, + internal_sender: to_async_client, + }); +} + +// Receive messages from the async client tasks and update the sync client. +fn update_sync_client(mut client: ResMut<Client>) { + while let Ok(message) = client.internal_receiver.try_recv() { + match message { + // TODO Clean: Raise a connected event + InternalAsyncMessage::Connected => client.state = ClientState::Connected, + } + } +} + +pub struct QuinnetClientPlugin {} + +impl Default for QuinnetClientPlugin { + fn default() -> Self { + Self {} + } +} + +impl Plugin for QuinnetClientPlugin { + fn build(&self, app: &mut App) { + app.insert_resource( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ) + // StartupStage::PreStartup so that resources created in commands are available to default startup_systems + .add_startup_system_to_stage(StartupStage::PreStartup, initialize_client) + .add_system(update_sync_client); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..615806a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,30 @@ +use serde::{Deserialize, Serialize}; + +pub const DEFAULT_MESSAGE_QUEUE_SIZE: usize = 150; +pub const DEFAULT_KILL_MESSAGE_QUEUE_SIZE: usize = 10; +pub const DEFAULT_KEEP_ALIVE_INTERVAL_S: u64 = 4; + +pub mod client; +pub mod server; + +pub type ClientId = u64; + +/// Enum with possibles errors that can occur. +#[derive(Debug)] +pub enum QuinnetError { + /// Failed serialization + Serialization, + /// Failed deserialization + Deserialization, + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + FullQueue, + /// The receive half of the channel was explicitly closed or has been + /// dropped. + ChannelClosed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum ServerMessage { + AssignId { client_id: ClientId }, +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..9fa010d --- /dev/null +++ b/src/server.rs @@ -0,0 +1,276 @@ +use std::{ + collections::{HashMap}, + error::Error, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; + +use bevy::prelude::*; +use bytes::Bytes; +use futures::{ + sink::SinkExt, +}; +use futures_util::StreamExt; +use quinn::{Endpoint, NewConnection, ServerConfig}; +use serde::Deserialize; +use tokio::{ + runtime::Runtime, + sync::{ + broadcast::{self, error::SendError}, + mpsc::{ + self, + error::{TryRecvError}, + }, + }, + task::{ JoinSet}, +}; +use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; + +use crate::{ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_MESSAGE_QUEUE_SIZE, DEFAULT_KILL_MESSAGE_QUEUE_SIZE}; + +#[derive(Deserialize)] +pub struct ServerConfigurationData { + host: String, + port: u16, + local_bind_host: String, +} + +#[derive(Debug)] +pub struct ClientPayload { + client_id: ClientId, + msg: Bytes, +} + +pub struct Server { + broadcast_sender: broadcast::Sender<Bytes>, + receiver: mpsc::Receiver<ClientPayload>, + close_senders: Arc<Mutex<HashMap<ClientId, broadcast::Sender<()>>>>, +} + +impl Server { + pub fn disconnect_client(&mut self, client_id: ClientId) { + match self.close_senders.lock() { + Ok(senders) => { + match senders.get(&client_id) { + Some(close_sender) => + if let Err(_) = close_sender.send(()) { + error!("Failed to close client streams & resources while disconnecting client {}", client_id) + } , + None => warn!("Tried to disconnect unknown client {}", client_id), + } + }, + Err(_) => error!("Failed to acquire lock while disconnecting client {}", client_id), + } + } + + pub fn receive_message<T: serde::de::DeserializeOwned>( + &mut self, + ) -> Result<Option<(T, ClientId)>, QuinnetError> { + match self.receive_payload()? { + Some(client_msg) => match bincode::deserialize(&client_msg.msg) { + Ok(msg) => Ok(Some((msg, client_msg.client_id))), + Err(_) => Err(QuinnetError::Deserialization), + }, + None => Ok(None), + } + } + + pub fn broadcast_message<T: serde::Serialize>( + &mut self, + message: T, + ) -> Result<(), QuinnetError> { + match bincode::serialize(&message) { + Ok(payload) => self.broadcast_payload(payload), + Err(_) => Err(QuinnetError::Serialization), + } + } + + pub fn broadcast_payload<T: Into<Bytes>>(&mut self, payload: T) -> Result<(), QuinnetError> { + match self.broadcast_sender.send(payload.into()) { + Ok(_) => Ok(()), + Err(err) => match err { + SendError(_) => Err(QuinnetError::ChannelClosed), + } + } + } + + //TODO Clean: Consider receiving payloads for a specified client + pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> { + match self.receiver.try_recv() { + Ok(msg) => Ok(Some(msg)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed), + }, + } + } +} + +pub struct QuinnetServerPlugin {} + +impl Default for QuinnetServerPlugin { + fn default() -> Self { + Self {} + } +} + +impl Plugin for QuinnetServerPlugin { + fn build(&self, app: &mut App) { + app.insert_resource( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ) + .add_startup_system_to_stage(StartupStage::PreStartup, start_server); + } +} + +/// Returns default server configuration along with its certificate. +#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527 +fn configure_server(server_host: &String) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> { + let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap(); + let cert_der = cert.serialize_der().unwrap(); + let priv_key = rustls::PrivateKey(cert.serialize_private_key_der()); + let cert_chain = vec![rustls::Certificate(cert_der.clone())]; + + let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?; + Arc::get_mut(&mut server_config.transport) + .unwrap() + .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S))); + + Ok((server_config, cert_der)) +} + +fn start_server( + mut commands: Commands, + runtime: Res<Runtime>, + config: Res<ServerConfigurationData>, +) { + // TODO Fix: handle unwraps + let server_adr_str = format!("{}:{}", config.local_bind_host, config.port); + info!("Starting server on: {} ...", server_adr_str); + + let server_addr: SocketAddr = server_adr_str + .parse() + .expect("Failed to parse server address"); + + // TODO Security: Server certificate + let (server_config, server_cert) = + configure_server(&config.host).expect("Failed to configure server"); + + // TODO Clean: Configure size + let (from_clients_sender, from_clients_receiver) = + mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (broadcast_channel_sender, _) = + broadcast::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); + + let close_senders = Arc::new(Mutex::new(HashMap::new())); + commands.insert_resource(Server { + broadcast_sender: broadcast_channel_sender.clone(), + receiver: from_clients_receiver, + close_senders: close_senders.clone() + }); + + // Create server task + runtime.spawn(async move { + let mut client_gen_id: ClientId = 0; + let mut client_id_mappings = HashMap::new(); + + let (_endpoint, mut incoming) = + Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint"); + + // Start iterating over incoming connections/clients. + while let Some(conn) = incoming.next().await { + let mut new_connection: NewConnection = + conn.await.expect("Failed to handle incoming connection"); + let connection = new_connection.connection; + + // Attribute an id to this client + client_gen_id += 1; // TODO Fix: Better id generation/check + let client_id = client_gen_id; + client_id_mappings.insert(connection.stable_id(), client_id); + // TODO Clean: Raise a connection event to the sync side. + + info!( + "New connection from {}, client_id: {}, stable_id : {}", + connection.remote_address(), + client_id, + connection.stable_id() + ); + + // Create a close channel for this client + let (close_sender, mut close_receiver): ( + tokio::sync::broadcast::Sender<()>, + tokio::sync::broadcast::Receiver<()>, + ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); + + // Create the broadcast stream + let send_stream = connection + .open_uni() + .await + .expect( format!("Failed to open broadcast send stream for client: {}", client_id).as_str()); + let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new()); + let mut broadcast_channel_receiver = broadcast_channel_sender.subscribe(); + let _network_broadcaster = tokio::spawn(async move { + tokio::select! { + _ = close_receiver.recv() => { + trace!("Broadcaster forced to disconnected for client: {}", client_id) + } + _ = async { + while let Ok(msg_bytes) = broadcast_channel_receiver.recv().await { + // TODO Perf: Batch frames for a send_all + // TODO Clean: Error handling + if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { + error!("Error while broadcasting to client {}: {}", client_id, err); + }; + } + } => {} + } + }); + + // Spawn a task to listen for stream opened from this client + let from_client_sender_clone = from_clients_sender.clone(); + let mut uni_receivers:JoinSet<()> = JoinSet::new(); + let mut close_receiver = close_sender.subscribe(); + let _client_receiver = tokio::spawn(async move { + tokio::select! { + _ = close_receiver.recv() => { + trace!("New Stream listener forced to disconnected for client: {}", client_id) + } + _ = async { + // For each new stream opened by the client + while let Some(Ok(recv)) = new_connection.uni_streams.next().await { + let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); + + // Spawn a task to receive data on this stream. + let from_client_sender = from_client_sender_clone.clone(); + uni_receivers.spawn(async move { + while let Some(Ok(msg_bytes)) = frame_recv.next().await { + from_client_sender + .send(ClientPayload { + client_id: client_id, + msg: msg_bytes.into(), + }) + .await + .unwrap();// TODO Fix: error event + } + trace!("Unidirectional stream receiver ended for client: {}", client_id) + }); + } + } => { + trace!("New Stream listener ended for client: {}", client_id) + } + } + uni_receivers.shutdown().await; + trace!("All unidirectional stream receivers cleaned for client: {}", client_id) + }); + + { + let mut close_senders = close_senders.lock().unwrap(); + close_senders.insert(client_id, close_sender); + } + } + }); +} |