aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGilles Henaux <gill.henaux@gmail.com>2023-01-09 19:04:31 +0100
committerGitHub <noreply@github.com>2023-01-09 19:04:31 +0100
commita5921d1dd7526ed02605837b1079dc9c4934febe (patch)
treede721f8719ecee1f35b7cf73db2427da43dadc21 /src
parentc28dbcd67b6949802db5a4afcb7b350795db870a (diff)
parenta665a7adeecf6942bb17829132e0246ac263f932 (diff)
Merge pull request #6 from zheilbron/graceful_disconnect
Gracefully disconnect connections and trigger events
Diffstat (limited to 'src')
-rw-r--r--src/client.rs205
-rw-r--r--src/server.rs121
-rw-r--r--src/shared.rs2
3 files changed, 225 insertions, 103 deletions
diff --git a/src/client.rs b/src/client.rs
index 3369c1e..48255db 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -12,8 +12,8 @@ use bevy::prelude::*;
use bytes::Bytes;
use futures::sink::SinkExt;
use futures_util::StreamExt;
-use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint};
-use quinn_proto::ConnectionStats;
+use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, SendStream};
+use quinn_proto::{ConnectionStats, VarInt};
use serde::Deserialize;
use tokio::{
runtime::{self},
@@ -105,8 +105,9 @@ type InternalConnectionRef = QuinnConnection;
/// Current state of a client connection
#[derive(Debug)]
enum ConnectionState {
- Disconnected,
+ Connecting,
Connected(InternalConnectionRef),
+ Disconnected,
}
#[derive(Debug)]
@@ -171,9 +172,12 @@ impl Connection {
}
pub fn send_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> {
- match bincode::serialize(&message) {
- Ok(payload) => self.send_payload(payload),
- Err(_) => Err(QuinnetError::Serialization),
+ match &self.state {
+ ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed),
+ _ => match bincode::serialize(&message) {
+ Ok(payload) => self.send_payload(payload),
+ Err(_) => Err(QuinnetError::Serialization),
+ },
}
}
@@ -186,11 +190,14 @@ impl Connection {
}
pub fn send_payload<T: Into<Bytes>>(&self, payload: T) -> Result<(), QuinnetError> {
- match self.sender.try_send(payload.into()) {
- Ok(_) => Ok(()),
- Err(err) => match err {
- TrySendError::Full(_) => Err(QuinnetError::FullQueue),
- TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed),
+ match &self.state {
+ ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed),
+ _ => match self.sender.try_send(payload.into()) {
+ Ok(_) => Ok(()),
+ Err(err) => match err {
+ TrySendError::Full(_) => Err(QuinnetError::FullQueue),
+ TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed),
+ },
},
}
}
@@ -204,11 +211,14 @@ impl Connection {
}
pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> {
- match self.receiver.try_recv() {
- Ok(msg_payload) => Ok(Some(msg_payload)),
- Err(err) => match err {
- TryRecvError::Empty => Ok(None),
- TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed),
+ match &self.state {
+ ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed),
+ _ => match self.receiver.try_recv() {
+ Ok(msg_payload) => Ok(Some(msg_payload)),
+ Err(err) => match err {
+ TryRecvError::Empty => Ok(None),
+ TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed),
+ },
},
}
}
@@ -226,27 +236,30 @@ impl Connection {
/// Disconnect from the server on this connection. This does not send any message to the server, and simply closes all the connection's tasks locally.
fn disconnect(&mut self) -> Result<(), QuinnetError> {
- if self.is_connected() {
- if let Err(_) = self.close_sender.send(()) {
- return Err(QuinnetError::ChannelClosed);
+ match &self.state {
+ ConnectionState::Disconnected => Ok(()),
+ _ => {
+ self.state = ConnectionState::Disconnected;
+ match self.close_sender.send(()) {
+ Ok(_) => Ok(()),
+ Err(_) => Err(QuinnetError::ChannelClosed),
+ }
}
}
- self.state = ConnectionState::Disconnected;
- Ok(())
}
pub fn is_connected(&self) -> bool {
match self.state {
- ConnectionState::Disconnected => false,
ConnectionState::Connected(_) => true,
+ _ => false,
}
}
/// Returns statistics about the current connection if connected.
pub fn stats(&self) -> Option<ConnectionStats> {
match &self.state {
- ConnectionState::Disconnected => None,
ConnectionState::Connected(connection) => Some(connection.stats()),
+ _ => None,
}
}
}
@@ -331,7 +344,7 @@ impl Client {
) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
let connection = Connection {
- state: ConnectionState::Disconnected,
+ state: ConnectionState::Connecting,
sender: to_server_sender,
receiver: from_server_receiver,
close_sender: close_sender.clone(),
@@ -474,57 +487,100 @@ async fn connection_task(mut spawn_config: ConnectionSpawnConfig) {
.expect("Failed to open send stream");
let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new());
- let close_sender_clone = spawn_config.close_sender.clone();
- let _network_sends = tokio::spawn(async move {
- tokio::select! {
- _ = spawn_config.close_receiver.recv() => {
- trace!("Unidirectional send Stream forced to disconnected")
- }
- _ = async {
- while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await {
- if let Err(err) = frame_send.send(msg_bytes).await {
- error!("Error while sending, {}", err); // TODO Clean: error handling
- error!("Client seems disconnected, closing resources");
- if let Err(_) = close_sender_clone.send(()) {
- error!("Failed to close all client streams & resources")
- }
- spawn_config.to_sync_client.send(
- InternalAsyncMessage::LostConnection)
- .await
- .expect("Failed to signal connection lost to sync client");
+ let _close_waiter = {
+ let conn = connection.clone();
+ let to_sync_client = spawn_config.to_sync_client.clone();
+ let close_sender = spawn_config.close_sender.clone();
+ tokio::spawn(async move {
+ let conn_err = conn.closed().await;
+ info!("Disconnected: {}", conn_err);
+ close_sender.send(()).ok();
+ to_sync_client
+ .send(InternalAsyncMessage::LostConnection)
+ .await
+ .expect("Failed to signal connection lost to sync client");
+ })
+ };
+
+ let _network_sends = {
+ let close_sender_clone = spawn_config.close_sender.clone();
+ let conn = connection.clone();
+ tokio::spawn(async move {
+ tokio::select! {
+ _ = spawn_config.close_receiver.recv() => {
+ trace!("Unidirectional send Stream forced to disconnected")
+ }
+ _ = async {
+ while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await {
+ send_msg(&close_sender_clone, &spawn_config.to_sync_client, &mut frame_send, msg_bytes).await;
}
+ } => {
+ trace!("Unidirectional send Stream ended")
}
- } => {
- trace!("Unidirectional send Stream ended")
}
- }
- });
-
- let mut uni_receivers: JoinSet<()> = JoinSet::new();
- let mut close_receiver = spawn_config.close_sender.subscribe();
- let _network_reads = tokio::spawn(async move {
- tokio::select! {
- _ = close_receiver.recv() => {
- trace!("New Stream listener forced to disconnected")
+ while let Ok(msg_bytes) = spawn_config.to_server_receiver.try_recv() {
+ if let Err(err) = frame_send.send(msg_bytes).await {
+ error!("Error while sending, {}", err);
+ }
}
- _ = async {
- while let Ok(recv)= connection.accept_uni().await {
- let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
- let from_server_sender = spawn_config.from_server_sender.clone();
-
- uni_receivers.spawn(async move {
- while let Some(Ok(msg_bytes)) = frame_recv.next().await {
- from_server_sender.send(msg_bytes.into()).await.unwrap(); // TODO Clean: error handling
- }
- });
+ if let Err(err) = frame_send.flush().await {
+ error!("Error while flushing stream: {}", err);
+ }
+ if let Err(err) = frame_send.into_inner().finish().await {
+ error!("Failed to shutdown stream gracefully: {}", err);
+ }
+ conn.close(VarInt::from_u32(0), "closed".as_bytes());
+ })
+ };
+
+ let _network_reads = {
+ let mut uni_receivers: JoinSet<()> = JoinSet::new();
+ let mut close_receiver = spawn_config.close_sender.subscribe();
+ tokio::spawn(async move {
+ tokio::select! {
+ _ = close_receiver.recv() => {
+ trace!("New Stream listener forced to disconnected")
+ }
+ _ = async {
+ while let Ok(recv)= connection.accept_uni().await {
+ let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
+ let from_server_sender = spawn_config.from_server_sender.clone();
+
+ uni_receivers.spawn(async move {
+ while let Some(Ok(msg_bytes)) = frame_recv.next().await {
+ from_server_sender.send(msg_bytes.into()).await.unwrap(); // TODO Clean: error handling
+ }
+ });
+ }
+ } => {
+ trace!("New Stream listener ended ")
}
- } => {
- trace!("New Stream listener ended ")
}
- }
- uni_receivers.shutdown().await;
- trace!("All unidirectional stream receivers cleaned");
- });
+ uni_receivers.shutdown().await;
+ trace!("All unidirectional stream receivers cleaned");
+ })
+ };
+ }
+ }
+}
+
+async fn send_msg(
+ close_sender: &tokio::sync::broadcast::Sender<()>,
+ to_sync_client: &mpsc::Sender<InternalAsyncMessage>,
+ frame_send: &mut FramedWrite<SendStream, LengthDelimitedCodec>,
+ msg_bytes: Bytes,
+) {
+ if let Err(err) = frame_send.send(msg_bytes).await {
+ error!("Error while sending, {}", err);
+ error!("Client seems disconnected, closing resources");
+ // Emit LostConnection to properly update the connection about its state.
+ // Raise LostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed.
+ to_sync_client
+ .send(InternalAsyncMessage::LostConnection)
+ .await
+ .expect("Failed to signal connection lost to sync client");
+ if let Err(_) = close_sender.send(()) {
+ error!("Failed to close all client streams & resources")
}
}
}
@@ -545,10 +601,13 @@ fn update_sync_client(
connection.state = ConnectionState::Connected(internal_connection);
connection_events.send(ConnectionEvent { id: *connection_id });
}
- InternalAsyncMessage::LostConnection => {
- connection.state = ConnectionState::Disconnected;
- connection_lost_events.send(ConnectionLostEvent { id: *connection_id });
- }
+ InternalAsyncMessage::LostConnection => match connection.state {
+ ConnectionState::Disconnected => (),
+ _ => {
+ connection.state = ConnectionState::Disconnected;
+ connection_lost_events.send(ConnectionLostEvent { id: *connection_id });
+ }
+ },
InternalAsyncMessage::CertificateInteractionRequest {
status,
info,
diff --git a/src/server.rs b/src/server.rs
index bde4419..f89ad58 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -4,7 +4,8 @@ use bevy::prelude::*;
use bytes::Bytes;
use futures::sink::SinkExt;
use futures_util::StreamExt;
-use quinn::{Endpoint as QuinnEndpoint, ServerConfig};
+use quinn::{Endpoint as QuinnEndpoint, SendStream, ServerConfig};
+use quinn_proto::VarInt;
use serde::Deserialize;
use tokio::{
runtime,
@@ -451,20 +452,22 @@ async fn handle_client_connection(
// Create an ordered reliable send channel for this client
let (to_client_sender, to_client_receiver) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
- let to_sync_server_clone = to_sync_server.clone();
- let close_sender_clone = client_close_sender.clone();
- let connection_clone = connection.clone();
- tokio::spawn(async move {
- client_sender_task(
- client_id,
- connection_clone,
- to_client_receiver,
- client_close_receiver,
- close_sender_clone,
- to_sync_server_clone,
- )
- .await
- });
+ let _client_sender = {
+ let to_sync_server_clone = to_sync_server.clone();
+ let close_sender_clone = client_close_sender.clone();
+ let connection_clone = connection.clone();
+ tokio::spawn(async move {
+ client_sender_task(
+ client_id,
+ connection_clone,
+ to_client_receiver,
+ client_close_receiver,
+ close_sender_clone,
+ to_sync_server_clone,
+ )
+ .await
+ })
+ };
// Signal the sync server of this new connection
to_sync_server
@@ -483,6 +486,21 @@ async fn handle_client_connection(
}
}
+ let _client_close_wait = {
+ let conn = connection.clone();
+ let close_sender = client_close_sender.clone();
+ let to_sync_server = to_sync_server.clone();
+ tokio::spawn(async move {
+ let conn_err = conn.closed().await;
+ info!("Client {} disconnected: {}", client_id, conn_err);
+ close_sender.send(()).ok();
+ to_sync_server
+ .send(InternalAsyncMessage::ClientLostConnection(client_id))
+ .await
+ .expect("Failed to signal connection lost to sync server");
+ });
+ };
+
// Spawn a task to listen for streams opened by this client
let _client_receiver = tokio::spawn(async move {
client_receiver_task(
@@ -519,22 +537,61 @@ async fn client_sender_task(
}
_ = async {
while let Some(msg_bytes) = to_client_receiver.recv().await {
- // TODO Perf: Batch frames for a send_all
- // TODO Clean: Error handling
- if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
- error!("Error while sending to client {}: {}", client_id, err);
- error!("Client {} seems disconnected, closing resources", client_id);
- if let Err(_) = close_sender.send(()) {
- error!("Failed to close all client streams & resources for client {}", client_id)
- }
- to_sync_server.send(
- InternalAsyncMessage::ClientLostConnection(client_id))
- .await
- .expect("Failed to signal connection lost to sync server");
- };
+ send_msg(
+ client_id,
+ &close_sender,
+ &to_sync_server,
+ &mut framed_send_stream,
+ msg_bytes,
+ )
+ .await
}
} => {}
}
+ while let Ok(msg_bytes) = to_client_receiver.try_recv() {
+ if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
+ error!("Error while sending to client {}: {}", client_id, err);
+ };
+ }
+ if let Err(err) = framed_send_stream.flush().await {
+ error!(
+ "Error while flushing stream to client {}: {}",
+ client_id, err
+ );
+ }
+ if let Err(err) = framed_send_stream.into_inner().finish().await {
+ error!(
+ "Failed to shutdown stream gracefully for client {}: {}",
+ client_id, err
+ );
+ }
+ connection.close(VarInt::from_u32(0), "closed".as_bytes());
+}
+
+async fn send_msg(
+ client_id: ClientId,
+ close_sender: &tokio::sync::broadcast::Sender<()>,
+ to_sync_server: &mpsc::Sender<InternalAsyncMessage>,
+ framed_send_stream: &mut FramedWrite<SendStream, LengthDelimitedCodec>,
+ msg_bytes: Bytes,
+) {
+ // TODO Perf: Batch frames for a send_all
+ if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
+ error!("Error while sending to client {}: {}", client_id, err);
+ error!("Client {} seems disconnected, closing resources", client_id);
+ // Emit ClientLostConnection to properly update the server about this client state.
+ // Raise ClientLostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed.
+ to_sync_server
+ .send(InternalAsyncMessage::ClientLostConnection(client_id))
+ .await
+ .expect("Failed to signal connection lost to sync server");
+ if let Err(_) = close_sender.send(()) {
+ error!(
+ "Failed to close all client streams & resources for client {}",
+ client_id
+ )
+ }
+ };
}
async fn client_receiver_task(
@@ -605,8 +662,12 @@ fn update_sync_server(
connection_events.send(ConnectionEvent { id: id });
}
InternalAsyncMessage::ClientLostConnection(client_id) => {
- endpoint.clients.remove(&client_id);
- connection_lost_events.send(ConnectionLostEvent { id: client_id });
+ match endpoint.clients.remove(&client_id) {
+ Some(_) => {
+ connection_lost_events.send(ConnectionLostEvent { id: client_id })
+ }
+ None => (),
+ }
}
}
}
diff --git a/src/shared.rs b/src/shared.rs
index 247841f..608bb0c 100644
--- a/src/shared.rs
+++ b/src/shared.rs
@@ -25,6 +25,8 @@ pub enum QuinnetError {
UnknownClient(ClientId),
#[error("Connection with id `{0}` is unknown")]
UnknownConnection(ConnectionId),
+ #[error("Connection is closed")]
+ ConnectionClosed,
#[error("Endpoint is already closed")]
EndpointAlreadyClosed,
#[error("Failed serialization")]