aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client.rs269
-rw-r--r--src/lib.rs30
-rw-r--r--src/server.rs276
3 files changed, 575 insertions, 0 deletions
diff --git a/src/client.rs b/src/client.rs
new file mode 100644
index 0000000..8d4ea8e
--- /dev/null
+++ b/src/client.rs
@@ -0,0 +1,269 @@
+use std::{net::SocketAddr, sync::Arc};
+
+use bevy::prelude::*;
+use bytes::Bytes;
+use futures::sink::SinkExt;
+use futures_util::StreamExt;
+use quinn::{ClientConfig, Endpoint};
+use serde::Deserialize;
+use tokio::{
+ runtime::Runtime,
+ sync::mpsc::{
+ self,
+ error::{TryRecvError, TrySendError},
+ },
+};
+use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
+
+use crate::{QuinnetError, DEFAULT_MESSAGE_QUEUE_SIZE};
+
+pub const DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE: usize = 10;
+
+#[derive(Deserialize)]
+pub struct ClientConfigurationData {
+ server_host: String,
+ server_port: u16,
+ local_bind_host: String,
+ local_bind_port: u16,
+}
+
/// Connection lifecycle state tracked on the sync side.
#[derive(Debug, PartialEq, Eq)]
enum ClientState {
    /// Waiting for the async task to report a successful connection.
    Connecting,
    /// Connection established (set by `update_sync_client` on receiving
    /// `InternalAsyncMessage::Connected`).
    Connected,
}
+
/// Messages sent from the async network task to the sync [`Client`].
#[derive(Debug)]
pub(crate) enum InternalAsyncMessage {
    /// The connection to the server succeeded.
    Connected,
}
+
/// Messages sent from the sync [`Client`] to the async network task.
#[derive(Debug)]
pub(crate) enum InternalSyncMessage {
    /// Request the async task to start connecting to the server.
    Connect,
}
+
/// Sync-side handle to the client connection, exposed as a Bevy resource.
pub struct Client {
    // Tracks whether the async task has reported a successful connection.
    state: ClientState,
    // TODO Perf: multiple channels
    // Outgoing payloads, consumed by the async send loop.
    sender: mpsc::Sender<Bytes>,
    // Incoming payloads, produced by the async receive loop.
    receiver: mpsc::Receiver<Bytes>,

    // Events from the async task (e.g. connection established).
    pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>,
    // Commands to the async task (e.g. request to connect).
    pub(crate) internal_sender: mpsc::Sender<InternalSyncMessage>,
}
+
+impl Client {
+ pub fn connect(&self) -> Result<(), QuinnetError> {
+ match self.internal_sender.try_send(InternalSyncMessage::Connect) {
+ Ok(_) => Ok(()),
+ Err(_) => Err(QuinnetError::FullQueue),
+ }
+ }
+
+ pub fn receive_message<T: serde::de::DeserializeOwned>(
+ &mut self,
+ ) -> Result<Option<T>, QuinnetError> {
+ match self.receive_payload()? {
+ Some(payload) => match bincode::deserialize(&payload) {
+ Ok(msg) => Ok(Some(msg)),
+ Err(_) => Err(QuinnetError::Deserialization),
+ },
+ None => Ok(None),
+ }
+ }
+
+ pub fn send_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> {
+ match bincode::serialize(&message) {
+ Ok(payload) => self.send_payload(payload),
+ Err(_) => Err(QuinnetError::Serialization),
+ }
+ }
+
+ pub fn send_payload<T: Into<Bytes>>(&self, payload: T) -> Result<(), QuinnetError> {
+ match self.sender.try_send(payload.into()) {
+ Ok(_) => Ok(()),
+ Err(err) => match err {
+ TrySendError::Full(_) => Err(QuinnetError::FullQueue),
+ TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed),
+ },
+ }
+ }
+
+ pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> {
+ match self.receiver.try_recv() {
+ Ok(msg_payload) => Ok(Some(msg_payload)),
+ Err(err) => match err {
+ TryRecvError::Empty => Ok(None),
+ TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed),
+ },
+ }
+ }
+
+ pub fn is_connected(&self) -> bool {
+ return self.state == ClientState::Connected;
+ }
+}
+
// Implementation of `ServerCertVerifier` that verifies everything as trustworthy.
// SECURITY NOTE(review): this disables server certificate validation entirely,
// leaving the connection open to man-in-the-middle attacks. Suitable for
// development/testing only.
struct SkipServerVerification;

impl SkipServerVerification {
    // Convenience constructor returning the `Arc` form expected by
    // `with_custom_certificate_verifier`.
    fn new() -> Arc<Self> {
        Arc::new(Self)
    }
}
+
impl rustls::client::ServerCertVerifier for SkipServerVerification {
    // Unconditionally accepts any server certificate; every input is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::Certificate,
        _intermediates: &[rustls::Certificate],
        _server_name: &rustls::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::ServerCertVerified::assertion())
    }
}
+
/// Builds a QUIC client configuration with default TLS settings, except that
/// server certificate verification is skipped entirely via
/// `SkipServerVerification`.
// SECURITY NOTE(review): any server certificate is accepted — development/
// testing only.
fn configure_client() -> ClientConfig {
    let crypto = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_custom_certificate_verifier(SkipServerVerification::new())
        .with_no_client_auth();

    ClientConfig::new(Arc::new(crypto))
}
+
/// Bevy startup system: creates the sync/async channel pairs, spawns the
/// client's async network task on the tokio runtime, and inserts the sync
/// [`Client`] resource (in state `Connecting`).
///
/// The async task waits for an [`InternalSyncMessage::Connect`] signal before
/// dialing the server, then runs its send and receive loops concurrently.
fn initialize_client(
    mut commands: Commands,
    runtime: Res<Runtime>,
    config: Res<ClientConfigurationData>,
) {
    let server_adr_str = format!("{}:{}", config.server_host, config.server_port);
    // Host kept separately: used as the TLS server name when connecting below.
    let srv_host = config.server_host.clone();
    let local_bind_adr = format!("{}:{}", config.local_bind_host, config.local_bind_port);

    info!("Trying to connect to server on: {} ...", server_adr_str);

    let server_addr: SocketAddr = server_adr_str
        .parse()
        .expect("Failed to parse server address");

    let client_cfg = configure_client();

    // Payload channels between the sync Client and the async task.
    let (from_server_sender, from_server_receiver) =
        mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
    let (to_server_sender, mut to_server_receiver) =
        mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);

    // Internal control/event channels between the sync Client and the async task.
    let (to_sync_client, from_async_client) =
        mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
    let (to_async_client, mut from_sync_client) =
        mpsc::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);

    runtime.spawn(async move {
        // Wait for a connection signal before starting client
        if let Some(message) = from_sync_client.recv().await {
            match message {
                InternalSyncMessage::Connect => info!("Client requested to connect"),
            }
        } else {
            warn!("Client closed before requesting a connection");
            return;
        }

        let mut endpoint = Endpoint::client(local_bind_adr.parse().unwrap())
            .expect("Failed to create client endpoint");
        endpoint.set_default_client_config(client_cfg);

        let mut new_connection = endpoint
            .connect(server_addr, &srv_host) // TODO Clean: error handling
            .expect("Failed to connect: configuration error")
            .await
            .expect("Failed to connect");
        info!(
            "Connected to {}",
            new_connection.connection.remote_address()
        );

        // Notify the sync side so it can flip its state to Connected.
        to_sync_client
            .send(InternalAsyncMessage::Connected)
            .await
            .expect("Failed to signal connection to sync client");

        // Single unidirectional stream towards the server, with
        // length-delimited framing.
        let send = new_connection
            .connection
            .open_uni()
            .await
            .expect("Failed to open send stream");
        let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new());

        // Forwards payloads queued by the sync Client onto the QUIC stream.
        let network_sends = async move {
            loop {
                if let Some(msg_bytes) = to_server_receiver.recv().await {
                    if let Err(err) = frame_send.send(msg_bytes).await {
                        error!("Error sending {}", err) // TODO Fix: error event
                    }
                }
            }
        };

        // Accepts each unidirectional stream opened by the server and pipes
        // its frames into the sync Client's receive channel.
        let network_reads = async move {
            while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
                let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
                let from_server_sender = from_server_sender.clone();
                tokio::spawn(async move {
                    while let Some(Ok(msg_bytes)) = frame_recv.next().await {
                        from_server_sender.send(msg_bytes.into()).await.unwrap();
                        // TODO Fix: error event
                    }
                });
            }
        };

        tokio::join!(network_sends, network_reads);
    });

    commands.insert_resource(Client {
        state: ClientState::Connecting,
        sender: to_server_sender,
        receiver: from_server_receiver,
        internal_receiver: from_async_client,
        internal_sender: to_async_client,
    });
}
+
+// Receive messages from the async client tasks and update the sync client.
+fn update_sync_client(mut client: ResMut<Client>) {
+ while let Ok(message) = client.internal_receiver.try_recv() {
+ match message {
+ // TODO Clean: Raise a connected event
+ InternalAsyncMessage::Connected => client.state = ClientState::Connected,
+ }
+ }
+}
+
/// Bevy plugin registering the Quinnet client: a tokio runtime resource, the
/// client initialization startup system, and the per-frame sync system.
// Derived `Default` replaces the hand-written impl; behavior is identical.
#[derive(Default)]
pub struct QuinnetClientPlugin {}
+
impl Plugin for QuinnetClientPlugin {
    fn build(&self, app: &mut App) {
        // Dedicated multi-threaded tokio runtime, shared as a Bevy resource;
        // the client's networking task is spawned on it.
        app.insert_resource(
            tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .unwrap(),
        )
        // StartupStage::PreStartup so that resources created in commands are available to default startup_systems
        .add_startup_system_to_stage(StartupStage::PreStartup, initialize_client)
        // Polls events from the async task every frame.
        .add_system(update_sync_client);
    }
}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..615806a
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,30 @@
+use serde::{Deserialize, Serialize};
+
/// Default capacity of the bounded channels carrying application payloads
/// between the sync and async sides.
pub const DEFAULT_MESSAGE_QUEUE_SIZE: usize = 150;
/// Default capacity of the per-client "close" broadcast channels.
pub const DEFAULT_KILL_MESSAGE_QUEUE_SIZE: usize = 10;
/// Default QUIC keep-alive interval, in seconds.
pub const DEFAULT_KEEP_ALIVE_INTERVAL_S: u64 = 4;

pub mod client;
pub mod server;

/// Identifier the server attributes to each connected client.
pub type ClientId = u64;
+
/// Possible errors that can occur while sending or receiving Quinnet messages.
#[derive(Debug)]
pub enum QuinnetError {
    /// Failed serialization
    Serialization,
    /// Failed deserialization
    Deserialization,
    /// The data could not be sent on the channel because the channel is
    /// currently full and sending would require blocking.
    FullQueue,
    /// The receive half of the channel was explicitly closed or has been
    /// dropped.
    ChannelClosed,
}

// Error enums should be usable with `?`, `Box<dyn Error>`, and logging:
// implement `Display` and `std::error::Error` (Rust error-type convention).
impl std::fmt::Display for QuinnetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            QuinnetError::Serialization => "failed to serialize message",
            QuinnetError::Deserialization => "failed to deserialize message",
            QuinnetError::FullQueue => "channel is full, sending would block",
            QuinnetError::ChannelClosed => "channel is closed",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for QuinnetError {}
+
/// Messages the server sends to clients.
// NOTE(review): not referenced by the visible code in this file — presumably
// consumed elsewhere; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ServerMessage {
    /// Notifies a client of the [`ClientId`] the server attributed to it.
    AssignId { client_id: ClientId },
}
diff --git a/src/server.rs b/src/server.rs
new file mode 100644
index 0000000..9fa010d
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,276 @@
+use std::{
+ collections::{HashMap},
+ error::Error,
+ net::SocketAddr,
+ sync::{Arc, Mutex},
+ time::Duration,
+};
+
+use bevy::prelude::*;
+use bytes::Bytes;
+use futures::{
+ sink::SinkExt,
+};
+use futures_util::StreamExt;
+use quinn::{Endpoint, NewConnection, ServerConfig};
+use serde::Deserialize;
+use tokio::{
+ runtime::Runtime,
+ sync::{
+ broadcast::{self, error::SendError},
+ mpsc::{
+ self,
+ error::{TryRecvError},
+ },
+ },
+ task::{ JoinSet},
+};
+use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
+
+use crate::{ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_MESSAGE_QUEUE_SIZE, DEFAULT_KILL_MESSAGE_QUEUE_SIZE};
+
/// Connection parameters for a Quinnet server, loaded via `Deserialize`.
#[derive(Deserialize)]
pub struct ServerConfigurationData {
    /// Hostname the self-signed certificate is generated for
    /// (see `configure_server`).
    host: String,
    /// UDP port to listen on.
    port: u16,
    /// Local address to bind on (combined with `port` in `start_server`).
    local_bind_host: String,
}
+
/// A raw message payload received from a specific client.
#[derive(Debug)]
pub struct ClientPayload {
    // Id of the client this payload came from.
    client_id: ClientId,
    // Raw, still-serialized message bytes.
    msg: Bytes,
}
+
/// Sync-side handle to the server, exposed as a Bevy resource.
pub struct Server {
    // Broadcasts payloads to every connected client's send task.
    broadcast_sender: broadcast::Sender<Bytes>,
    // Payloads received from all clients, tagged with their ClientId.
    receiver: mpsc::Receiver<ClientPayload>,
    // Per-client channels used to signal a client's tasks to shut down;
    // shared with the async accept loop, hence Arc<Mutex<..>>.
    close_senders: Arc<Mutex<HashMap<ClientId, broadcast::Sender<()>>>>,
}
+
+impl Server {
+ pub fn disconnect_client(&mut self, client_id: ClientId) {
+ match self.close_senders.lock() {
+ Ok(senders) => {
+ match senders.get(&client_id) {
+ Some(close_sender) =>
+ if let Err(_) = close_sender.send(()) {
+ error!("Failed to close client streams & resources while disconnecting client {}", client_id)
+ } ,
+ None => warn!("Tried to disconnect unknown client {}", client_id),
+ }
+ },
+ Err(_) => error!("Failed to acquire lock while disconnecting client {}", client_id),
+ }
+ }
+
+ pub fn receive_message<T: serde::de::DeserializeOwned>(
+ &mut self,
+ ) -> Result<Option<(T, ClientId)>, QuinnetError> {
+ match self.receive_payload()? {
+ Some(client_msg) => match bincode::deserialize(&client_msg.msg) {
+ Ok(msg) => Ok(Some((msg, client_msg.client_id))),
+ Err(_) => Err(QuinnetError::Deserialization),
+ },
+ None => Ok(None),
+ }
+ }
+
+ pub fn broadcast_message<T: serde::Serialize>(
+ &mut self,
+ message: T,
+ ) -> Result<(), QuinnetError> {
+ match bincode::serialize(&message) {
+ Ok(payload) => self.broadcast_payload(payload),
+ Err(_) => Err(QuinnetError::Serialization),
+ }
+ }
+
+ pub fn broadcast_payload<T: Into<Bytes>>(&mut self, payload: T) -> Result<(), QuinnetError> {
+ match self.broadcast_sender.send(payload.into()) {
+ Ok(_) => Ok(()),
+ Err(err) => match err {
+ SendError(_) => Err(QuinnetError::ChannelClosed),
+ }
+ }
+ }
+
+ //TODO Clean: Consider receiving payloads for a specified client
+ pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> {
+ match self.receiver.try_recv() {
+ Ok(msg) => Ok(Some(msg)),
+ Err(err) => match err {
+ TryRecvError::Empty => Ok(None),
+ TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed),
+ },
+ }
+ }
+}
+
/// Bevy plugin registering the Quinnet server: a tokio runtime resource and
/// the server startup system.
// Derived `Default` replaces the hand-written impl; behavior is identical.
#[derive(Default)]
pub struct QuinnetServerPlugin {}
+
impl Plugin for QuinnetServerPlugin {
    fn build(&self, app: &mut App) {
        // Dedicated multi-threaded tokio runtime, shared as a Bevy resource;
        // the server's accept loop and per-client tasks are spawned on it.
        app.insert_resource(
            tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .unwrap(),
        )
        // PreStartup so the Server resource exists before default startup systems run.
        .add_startup_system_to_stage(StartupStage::PreStartup, start_server);
    }
}
+
+/// Returns default server configuration along with its certificate.
+#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527
+fn configure_server(server_host: &String) -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
+ let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap();
+ let cert_der = cert.serialize_der().unwrap();
+ let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
+ let cert_chain = vec![rustls::Certificate(cert_der.clone())];
+
+ let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
+ Arc::get_mut(&mut server_config.transport)
+ .unwrap()
+ .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
+
+ Ok((server_config, cert_der))
+}
+
/// Bevy startup system: configures and binds the QUIC server endpoint, then
/// spawns the async accept loop on the tokio runtime.
///
/// For each incoming connection this attributes a fresh [`ClientId`], opens a
/// unidirectional broadcast stream towards the client, and spawns listener
/// tasks forwarding that client's incoming streams to the sync side through
/// the `Server` resource's channel.
fn start_server(
    mut commands: Commands,
    runtime: Res<Runtime>,
    config: Res<ServerConfigurationData>,
) {
    // TODO Fix: handle unwraps
    let server_adr_str = format!("{}:{}", config.local_bind_host, config.port);
    info!("Starting server on: {} ...", server_adr_str);

    let server_addr: SocketAddr = server_adr_str
        .parse()
        .expect("Failed to parse server address");

    // TODO Security: Server certificate
    // Self-signed certificate generated for `config.host`; the DER bytes are
    // unused in this function.
    let (server_config, server_cert) =
        configure_server(&config.host).expect("Failed to configure server");

    // TODO Clean: Configure size
    // Channel carrying payloads from all client tasks to the sync `Server`.
    let (from_clients_sender, from_clients_receiver) =
        mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
    // Broadcast channel: each client's send task subscribes to it.
    let (broadcast_channel_sender, _) =
        broadcast::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);

    let close_senders = Arc::new(Mutex::new(HashMap::new()));
    commands.insert_resource(Server {
        broadcast_sender: broadcast_channel_sender.clone(),
        receiver: from_clients_receiver,
        close_senders: close_senders.clone()
    });

    // Create server task
    runtime.spawn(async move {
        // Simple incrementing id generator; ids are never reused.
        let mut client_gen_id: ClientId = 0;
        // Maps quinn's connection `stable_id` to our ClientId.
        let mut client_id_mappings = HashMap::new();

        let (_endpoint, mut incoming) =
            Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint");

        // Start iterating over incoming connections/clients.
        while let Some(conn) = incoming.next().await {
            let mut new_connection: NewConnection =
                conn.await.expect("Failed to handle incoming connection");
            let connection = new_connection.connection;

            // Attribute an id to this client
            client_gen_id += 1; // TODO Fix: Better id generation/check
            let client_id = client_gen_id;
            client_id_mappings.insert(connection.stable_id(), client_id);
            // TODO Clean: Raise a connection event to the sync side.

            info!(
                "New connection from {}, client_id: {}, stable_id : {}",
                connection.remote_address(),
                client_id,
                connection.stable_id()
            );

            // Create a close channel for this client
            let (close_sender, mut close_receiver): (
                tokio::sync::broadcast::Sender<()>,
                tokio::sync::broadcast::Receiver<()>,
            ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);

            // Create the broadcast stream
            let send_stream = connection
                .open_uni()
                .await
                // NOTE(review): `format!` here runs even on success; prefer
                // `unwrap_or_else(|_| panic!(...))` (clippy: expect_fun_call).
                .expect( format!("Failed to open broadcast send stream for client: {}", client_id).as_str());
            let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
            let mut broadcast_channel_receiver = broadcast_channel_sender.subscribe();
            // Task forwarding broadcast payloads to this client until the
            // close signal fires.
            let _network_broadcaster = tokio::spawn(async move {
                tokio::select! {
                    _ = close_receiver.recv() => {
                        trace!("Broadcaster forced to disconnected for client: {}", client_id)
                    }
                    _ = async {
                        while let Ok(msg_bytes) = broadcast_channel_receiver.recv().await {
                            // TODO Perf: Batch frames for a send_all
                            // TODO Clean: Error handling
                            if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
                                error!("Error while broadcasting to client {}: {}", client_id, err);
                            };
                        }
                    } => {}
                }
            });

            // Spawn a task to listen for stream opened from this client
            let from_client_sender_clone = from_clients_sender.clone();
            let mut uni_receivers:JoinSet<()> = JoinSet::new();
            let mut close_receiver = close_sender.subscribe();
            let _client_receiver = tokio::spawn(async move {
                tokio::select! {
                    _ = close_receiver.recv() => {
                        trace!("New Stream listener forced to disconnected for client: {}", client_id)
                    }
                    _ = async {
                        // For each new stream opened by the client
                        while let Some(Ok(recv)) = new_connection.uni_streams.next().await {
                            let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());

                            // Spawn a task to receive data on this stream.
                            let from_client_sender = from_client_sender_clone.clone();
                            uni_receivers.spawn(async move {
                                while let Some(Ok(msg_bytes)) = frame_recv.next().await {
                                    from_client_sender
                                        .send(ClientPayload {
                                            client_id: client_id,
                                            msg: msg_bytes.into(),
                                        })
                                        .await
                                        .unwrap();// TODO Fix: error event
                                }
                                trace!("Unidirectional stream receiver ended for client: {}", client_id)
                            });
                        }
                    } => {
                        trace!("New Stream listener ended for client: {}", client_id)
                    }
                }
                uni_receivers.shutdown().await;
                trace!("All unidirectional stream receivers cleaned for client: {}", client_id)
            });

            {
                // Register the close sender so `Server::disconnect_client`
                // can signal this client's tasks.
                let mut close_senders = close_senders.lock().unwrap();
                close_senders.insert(client_id, close_sender);
            }
        }
    });
}