diff options
author | gilles henaux <gill.henaux@gmail.com> | 2023-01-12 15:43:20 +0100 |
---|---|---|
committer | gilles henaux <gill.henaux@gmail.com> | 2023-01-12 15:43:20 +0100 |
commit | d115261d0d2a9feaace1d5ba4e9e2fc9819994cb (patch) | |
tree | bfb1c930e9ed0abd375cdd60e24d93c727ca4195 /src | |
parent | 3d585a6fc850d77afc99385685d1f30626c0b6d0 (diff) |
[channels] Close the connection once all channel tasks have terminated
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 19 | ||||
-rw-r--r-- | src/shared/channel.rs | 17 |
2 files changed, 29 insertions, 7 deletions
diff --git a/src/client.rs b/src/client.rs index c3d2dfb..3e23d4a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,7 @@ use std::{ collections::{ hash_map::{Iter, IterMut}, - HashMap, + HashMap, HashSet, }, error::Error, net::SocketAddr, @@ -11,7 +11,7 @@ use std::{ use bevy::prelude::*; use bytes::Bytes; use futures_util::StreamExt; -use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint}; +use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, VarInt}; use quinn_proto::ConnectionStats; use serde::Deserialize; use tokio::{ @@ -665,6 +665,9 @@ async fn handle_connection_channels( mut from_sync_client: mpsc::Receiver<InternalSyncMessage>, to_sync_client: mpsc::Sender<InternalAsyncMessage>, ) { + // Use an mpsc channel where, instead of sending messages, we wait for the channel to be closed, which happens when every sender has been dropped. We can't use a JoinSet as simply here since we would also need to drain closed channels from it. + let (channel_tasks_keepalive, mut channel_tasks_waiter) = mpsc::channel(1); + let close_receiver_clone = close_receiver.resubscribe(); tokio::select! { _ = close_receiver.recv() => { @@ -677,11 +680,14 @@ async fn handle_connection_channels( let close_receiver = close_receiver_clone.resubscribe(); let connection_handle = connection.clone(); let to_sync_client = to_sync_client.clone(); + let channels_keepalive_clone = channel_tasks_keepalive.clone(); + match channel_id { ChannelId::OrderedReliable(_) => { tokio::spawn(async move { ordered_reliable_channel_task( connection_handle, + channels_keepalive_clone, to_sync_client, || InternalAsyncMessage::LostConnection, close_receiver, @@ -695,6 +701,7 @@ async fn handle_connection_channels( tokio::spawn(async move { unordered_reliable_channel_task( connection_handle, + channels_keepalive_clone, to_sync_client, || InternalAsyncMessage::LostConnection, close_receiver, @@ -711,6 +718,14 @@ async fn handle_connection_channels( trace!("Connection Channels listener ended") } }; + + // Wait for all the channels to have flushed/finished: + // We drop our sender first because the recv() call otherwise sleeps forever. + // When every sender has gone out of scope, the recv call will return with an error. We ignore the error. + drop(channel_tasks_keepalive); + let _ = channel_tasks_waiter.recv().await; + + connection.close(VarInt::from_u32(0), "closed".as_bytes()); } // Receive messages from the async client tasks and update the sync client. diff --git a/src/shared/channel.rs b/src/shared/channel.rs index 96d8afc..72dca41 100644 --- a/src/shared/channel.rs +++ b/src/shared/channel.rs @@ -2,7 +2,7 @@ use super::QuinnetError; use bevy::prelude::{error, trace}; use bytes::Bytes; use futures::sink::SinkExt; -use quinn::{SendStream, VarInt}; +use quinn::SendStream; use std::fmt::Debug; use tokio::sync::{ broadcast, @@ -69,6 +69,7 @@ impl Channel { pub(crate) async fn ordered_reliable_channel_task<T: Debug>( connection: quinn::Connection, + _: mpsc::Sender<()>, to_sync_client: mpsc::Sender<T>, on_lost_connection: fn() -> T, mut close_receiver: broadcast::Receiver<()>, @@ -97,7 +98,7 @@ pub(crate) async fn ordered_reliable_channel_task<T: Debug>( } => { trace!("Ordered Reliable Channel task ended") } - } + }; while let Ok(msg_bytes) = to_server_receiver.try_recv() { if let Err(err) = frame_sender.send(msg_bytes).await { error!("Error while sending, {}", err); @@ -109,12 +110,11 @@ pub(crate) async fn ordered_reliable_channel_task<T: Debug>( if let Err(err) = frame_sender.into_inner().finish().await { error!("Failed to shutdown stream gracefully: {}", err); } - todo!("Do not close here, wait for all channels to be flushed"); - connection.close(VarInt::from_u32(0), "closed".as_bytes()); } pub(crate) async fn unordered_reliable_channel_task<T: Debug>( connection: quinn::Connection, + _: mpsc::Sender<()>, to_sync_client: mpsc::Sender<T>, on_lost_connection: fn() -> T, mut close_receiver: broadcast::Receiver<()>, @@ -138,12 +138,19 @@ pub(crate) async fn unordered_reliable_channel_task<T: Debug>( .await .expect("Failed to signal connection lost to sync client"); } + todo!("finish the stream") } } => { trace!("Unordered Reliable Channel task ended") } + }; + while let Ok(msg_bytes) = to_server_receiver.try_recv() { + let mut frame_sender = new_uni_frame_sender(&connection).await; + if let Err(err) = frame_sender.send(msg_bytes).await { + error!("Error while sending, {}", err); + } + todo!("finish the stream") } - todo!("Flush and signal finished") } async fn new_uni_frame_sender( |