aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenauxg <19689618+Henauxg@users.noreply.github.com>2023-01-18 15:56:28 +0100
committerHenauxg <19689618+Henauxg@users.noreply.github.com>2023-01-18 15:56:28 +0100
commit0aa6c92b1bf9713aea973c414675c9f21da07781 (patch)
tree94490f2eeda7e9142964a1e248aff4348c6293a0
parentbcaba7612398a31c8f9ca7d40bcfe66b0c25a62d (diff)
[server] Change receiving API to be client specific & use shared receptino tasks
-rw-r--r--src/server.rs209
1 files changed, 46 insertions, 163 deletions
diff --git a/src/server.rs b/src/server.rs
index 0ac916b..6729728 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -7,8 +7,7 @@ use std::{
use bevy::prelude::*;
use bytes::Bytes;
-use futures_util::StreamExt;
-use quinn::{ConnectionError, Endpoint as QuinnEndpoint, RecvStream, ServerConfig};
+use quinn::{ConnectionError, Endpoint as QuinnEndpoint, ServerConfig};
use serde::Deserialize;
use tokio::{
runtime,
@@ -20,14 +19,14 @@ use tokio::{
},
},
};
-use tokio_util::codec::{FramedRead, LengthDelimitedCodec};
use crate::{
server::certificate::retrieve_certificate,
shared::{
channel::{
- channels_task, get_channel_id_from_type, Channel, ChannelAsyncMessage, ChannelId,
- ChannelSyncMessage, ChannelType, MultiChannelId,
+ channels_task, get_channel_id_from_type, reliable_receiver_task,
+ unreliable_receiver_task, Channel, ChannelAsyncMessage, ChannelId, ChannelSyncMessage,
+ ChannelType, MultiChannelId,
},
AsyncRuntime, ClientId, QuinnetError, DEFAULT_KEEP_ALIVE_INTERVAL_S,
DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE,
@@ -98,30 +97,17 @@ impl ServerConfigurationData {
}
}
-/// Represents a client message in its binary form
-#[derive(Debug)]
-pub struct ClientPayload {
- /// Id of the client sending the message
- client_id: ClientId,
- /// Content of the message as bytes
- msg: Bytes,
-}
-
#[derive(Debug)]
pub(crate) enum ServerAsyncMessage {
ClientConnected(ClientConnection),
ClientConnectionClosed(ClientId, ConnectionError),
}
-#[derive(Debug, Clone)]
-pub(crate) enum ServerSyncMessage {
- ClientConnectedAck(ClientId),
-}
-
#[derive(Debug)]
-pub(crate) struct ClientConnection {
+pub struct ClientConnection {
client_id: ClientId,
channels: HashMap<ChannelId, Channel>,
+ bytes_from_client_recv: mpsc::Receiver<Bytes>,
close_sender: broadcast::Sender<()>,
pub(crate) to_channels_send: mpsc::Sender<ChannelSyncMessage>,
@@ -173,28 +159,35 @@ pub struct Endpoint {
channels: HashSet<ChannelId>,
default_channel: Option<ChannelId>,
last_gen_id: MultiChannelId,
- payloads_from_clients_recv: mpsc::Receiver<ClientPayload>,
close_sender: broadcast::Sender<()>,
pub(crate) from_async_server_recv: mpsc::Receiver<ServerAsyncMessage>,
- pub(crate) to_async_server_send: broadcast::Sender<ServerSyncMessage>,
}
impl Endpoint {
- pub fn receive_message<T: serde::de::DeserializeOwned>(
+ /// Returns an iterator over all client ids
+ pub fn clients(&self) -> Vec<ClientId> {
+ self.clients.keys().cloned().collect()
+ }
+
+ pub fn receive_message_from<T: serde::de::DeserializeOwned>(
&mut self,
- ) -> Result<Option<(T, ClientId)>, QuinnetError> {
- match self.receive_payload()? {
- Some(client_msg) => match bincode::deserialize(&client_msg.msg) {
- Ok(msg) => Ok(Some((msg, client_msg.client_id))),
+ client_id: ClientId,
+ ) -> Result<Option<T>, QuinnetError> {
+ match self.receive_payload_from(client_id)? {
+ Some(payload) => match bincode::deserialize(&payload) {
+ Ok(msg) => Ok(Some(msg)),
Err(_) => Err(QuinnetError::Deserialization),
},
None => Ok(None),
}
}
- pub fn try_receive_message<T: serde::de::DeserializeOwned>(&mut self) -> Option<(T, ClientId)> {
- match self.receive_message() {
+ pub fn try_receive_message_from<T: serde::de::DeserializeOwned>(
+ &mut self,
+ client_id: ClientId,
+ ) -> Option<T> {
+ match self.receive_message_from(client_id) {
Ok(message) => message,
Err(err) => {
error!("try_receive_message: {}", err);
@@ -203,18 +196,24 @@ impl Endpoint {
}
}
- pub fn receive_payload(&mut self) -> Result<Option<ClientPayload>, QuinnetError> {
- match self.payloads_from_clients_recv.try_recv() {
- Ok(msg) => Ok(Some(msg)),
- Err(err) => match err {
- TryRecvError::Empty => Ok(None),
- TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed),
+ pub fn receive_payload_from(
+ &mut self,
+ client_id: ClientId,
+ ) -> Result<Option<Bytes>, QuinnetError> {
+ match self.clients.get_mut(&client_id) {
+ Some(client) => match client.bytes_from_client_recv.try_recv() {
+ Ok(msg) => Ok(Some(msg)),
+ Err(err) => match err {
+ TryRecvError::Empty => Ok(None),
+ TryRecvError::Disconnected => Err(QuinnetError::InternalChannelClosed),
+ },
},
+ None => Err(QuinnetError::UnknownClient(client_id)),
}
}
- pub fn try_receive_payload(&mut self) -> Option<ClientPayload> {
- match self.receive_payload() {
+ pub fn try_receive_payload_from(&mut self, client_id: ClientId) -> Option<Bytes> {
+ match self.receive_payload_from(client_id) {
Ok(payload) => payload,
Err(err) => {
error!("try_receive_payload: {}", err);
@@ -548,16 +547,8 @@ impl Endpoint {
connection.create_channel(*channel_id)?;
}
- match self
- .to_async_server_send
- .send(ServerSyncMessage::ClientConnectedAck(connection.client_id))
- {
- Ok(_) => {
- self.clients.insert(connection.client_id, connection);
- Ok(())
- }
- Err(_) => Err(QuinnetError::InternalChannelClosed),
- }
+ self.clients.insert(connection.client_id, connection);
+ Ok(())
}
}
}
@@ -607,12 +598,8 @@ impl Server {
.ok_or(QuinnetError::LockAcquisitionFailure)?
.keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
- let (payloads_from_clients_send, payloads_from_clients_recv) =
- mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
let (to_sync_server_send, from_async_server_recv) =
mpsc::channel::<ServerAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
- let (to_async_server_send, from_sync_server_recv) =
- broadcast::channel::<ServerSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
let (endpoint_close_send, endpoint_close_recv) =
broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
@@ -624,8 +611,6 @@ impl Server {
server_addr,
to_sync_server_send.clone(),
endpoint_close_recv,
- from_sync_server_recv,
- payloads_from_clients_send.clone(),
)
.await;
});
@@ -635,10 +620,8 @@ impl Server {
channels: HashSet::new(),
default_channel: None,
last_gen_id: 0,
- payloads_from_clients_recv,
close_sender: endpoint_close_send,
from_async_server_recv,
- to_async_server_send: to_async_server_send.clone(),
};
let ordered_reliable_id = endpoint.open_default_channels()?;
self.endpoint = Some(endpoint);
@@ -673,8 +656,6 @@ async fn endpoint_task(
endpoint_adr: SocketAddr,
to_sync_server_send: mpsc::Sender<ServerAsyncMessage>,
mut endpoint_close_recv: broadcast::Receiver<()>,
- from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>,
- payloads_from_clients_send: mpsc::Sender<ClientPayload>,
) {
let mut client_gen_id: ClientId = 0;
let mut client_id_mappings = HashMap::new();
@@ -697,15 +678,11 @@ async fn endpoint_task(
{
let to_sync_server_send = to_sync_server_send.clone();
- let from_sync_server_recv = from_sync_server_recv.resubscribe();
- let payloads_from_clients_send = payloads_from_clients_send.clone();
tokio::spawn(async move {
client_connection_task(
connection,
client_id,
- to_sync_server_send,
- from_sync_server_recv,
- payloads_from_clients_send,
+ to_sync_server_send
)
.await
});
@@ -721,8 +698,6 @@ async fn client_connection_task(
connection: quinn::Connection,
client_id: ClientId,
to_sync_server_send: mpsc::Sender<ServerAsyncMessage>,
- mut from_sync_server_recv: broadcast::Receiver<ServerSyncMessage>,
- payloads_from_clients_send: mpsc::Sender<ClientPayload>,
) {
info!(
"New connection from {}, client_id: {}",
@@ -732,6 +707,8 @@ async fn client_connection_task(
let (client_close_send, client_close_recv) =
broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
+ let (bytes_from_client_send, bytes_from_client_recv) =
+ mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
let (from_channels_send, from_channels_recv) =
mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
let (to_channels_send, to_channels_recv) =
@@ -742,6 +719,7 @@ async fn client_connection_task(
.send(ServerAsyncMessage::ClientConnected(ClientConnection {
client_id: client_id,
channels: HashMap::new(),
+ bytes_from_client_recv,
close_sender: client_close_send.clone(),
from_channels_recv,
to_channels_send,
@@ -749,13 +727,6 @@ async fn client_connection_task(
.await
.expect("Failed to signal connection to sync client");
- // Wait for the sync server to acknowledge the connection before spawning reception tasks.
- while let Ok(ServerSyncMessage::ClientConnectedAck(id)) = from_sync_server_recv.recv().await {
- if id == client_id {
- break;
- }
- }
-
// Spawn a task to listen for the underlying connection being closed
{
let conn = connection.clone();
@@ -779,13 +750,13 @@ async fn client_connection_task(
{
let connection_handle = connection.clone();
let client_close_recv = client_close_recv.resubscribe();
- let payloads_incoming_send = payloads_from_clients_send.clone();
+ let bytes_incoming_send = bytes_from_client_send.clone();
tokio::spawn(async move {
reliable_receiver_task(
client_id,
connection_handle,
client_close_recv,
- payloads_incoming_send,
+ bytes_incoming_send,
)
.await
});
@@ -795,13 +766,13 @@ async fn client_connection_task(
{
let connection_handle = connection.clone();
let client_close_recv = client_close_recv.resubscribe();
- let payloads_incoming_send = payloads_from_clients_send.clone();
+ let bytes_incoming_send = bytes_from_client_send.clone();
tokio::spawn(async move {
unreliable_receiver_task(
client_id,
connection_handle,
client_close_recv,
- payloads_incoming_send,
+ bytes_incoming_send,
)
.await
});
@@ -819,94 +790,6 @@ async fn client_connection_task(
});
}
-async fn uni_receiver_task(
- client_id: ClientId,
- mut close_recv: broadcast::Receiver<()>,
- recv: RecvStream,
- payloads_from_clients_send: mpsc::Sender<ClientPayload>,
-) {
- tokio::select! {
- _ = close_recv.recv() => {
- trace!("Listener of a Unidirectional Receiving Streams received a close signal")
- }
- _ = async {
- let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
- while let Some(Ok(msg_bytes)) = frame_recv.next().await {
- // TODO Clean: error handling
- payloads_from_clients_send
- .send(ClientPayload {
- client_id: client_id,
- msg: msg_bytes.into(),
- })
- .await
- .unwrap();
- }
- } => {}
- };
-}
-
-async fn reliable_receiver_task(
- client_id: ClientId,
- connection: quinn::Connection,
- mut close_recv: tokio::sync::broadcast::Receiver<()>,
- payloads_from_clients_send: mpsc::Sender<ClientPayload>,
-) {
- let close_recv_clone = close_recv.resubscribe();
- tokio::select! {
- _ = close_recv.recv() => {
- trace!("Listener for new Unidirectional Receiving Streams received a close signal for client: {}", client_id)
- }
- _ = async {
- while let Ok(recv) = connection.accept_uni().await {
- let payloads_from_clients_send = payloads_from_clients_send.clone();
- let close_recv_clone = close_recv_clone.resubscribe();
- tokio::spawn(async move {
- uni_receiver_task(
- client_id,
- close_recv_clone,
- recv,
- payloads_from_clients_send
- ).await;
- });
- }
- } => {
- trace!("New Stream listener ended for client: {}", client_id)
- }
- }
- trace!(
- "All unidirectional stream receivers cleaned for client: {}",
- client_id
- )
-}
-
-async fn unreliable_receiver_task(
- client_id: ClientId,
- connection: quinn::Connection,
- mut close_recv: broadcast::Receiver<()>,
- payloads_incoming_send: mpsc::Sender<ClientPayload>,
-) {
- tokio::select! {
- _ = close_recv.recv() => {
- trace!("Listener for unreliable datagrams received a close signal for client: {}",
- client_id)
- }
- _ = async {
- while let Ok(msg_bytes) = connection.read_datagram().await {
- // TODO Clean: error handling
- payloads_incoming_send.send(ClientPayload {
- client_id: client_id,
- msg: msg_bytes.into(),
- })
- .await
- .unwrap();
- }
- } => {
- trace!("Listener for unreliable datagrams ended for client: {}",
- client_id)
- }
- };
-}
-
fn create_server(mut commands: Commands, runtime: Res<AsyncRuntime>) {
commands.insert_resource(Server {
endpoint: None,