aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorgilles henaux <gill.henaux@gmail.com>2023-01-12 15:43:20 +0100
committergilles henaux <gill.henaux@gmail.com>2023-01-12 15:43:20 +0100
commitd115261d0d2a9feaace1d5ba4e9e2fc9819994cb (patch)
treebfb1c930e9ed0abd375cdd60e24d93c727ca4195 /src
parent3d585a6fc850d77afc99385685d1f30626c0b6d0 (diff)
[channels] Close the connection once all channel tasks have terminated
Diffstat (limited to 'src')
-rw-r--r--src/client.rs19
-rw-r--r--src/shared/channel.rs17
2 files changed, 29 insertions, 7 deletions
diff --git a/src/client.rs b/src/client.rs
index c3d2dfb..3e23d4a 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,7 +1,7 @@
use std::{
collections::{
hash_map::{Iter, IterMut},
- HashMap,
+ HashMap, HashSet,
},
error::Error,
net::SocketAddr,
@@ -11,7 +11,7 @@ use std::{
use bevy::prelude::*;
use bytes::Bytes;
use futures_util::StreamExt;
-use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint};
+use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, VarInt};
use quinn_proto::ConnectionStats;
use serde::Deserialize;
use tokio::{
@@ -665,6 +665,9 @@ async fn handle_connection_channels(
mut from_sync_client: mpsc::Receiver<InternalSyncMessage>,
to_sync_client: mpsc::Sender<InternalAsyncMessage>,
) {
+ // Use an mpsc channel where, instead of sending messages, we wait for the channel to be closed, which happens when every sender has been dropped. We can't use a JoinSet as simply here since we would also need to drain closed channels from it.
+ let (channel_tasks_keepalive, mut channel_tasks_waiter) = mpsc::channel(1);
+
let close_receiver_clone = close_receiver.resubscribe();
tokio::select! {
_ = close_receiver.recv() => {
@@ -677,11 +680,14 @@ async fn handle_connection_channels(
let close_receiver = close_receiver_clone.resubscribe();
let connection_handle = connection.clone();
let to_sync_client = to_sync_client.clone();
+ let channels_keepalive_clone = channel_tasks_keepalive.clone();
+
match channel_id {
ChannelId::OrderedReliable(_) => {
tokio::spawn(async move {
ordered_reliable_channel_task(
connection_handle,
+ channels_keepalive_clone,
to_sync_client,
|| InternalAsyncMessage::LostConnection,
close_receiver,
@@ -695,6 +701,7 @@ async fn handle_connection_channels(
tokio::spawn(async move {
unordered_reliable_channel_task(
connection_handle,
+ channels_keepalive_clone,
to_sync_client,
|| InternalAsyncMessage::LostConnection,
close_receiver,
@@ -711,6 +718,14 @@ async fn handle_connection_channels(
trace!("Connection Channels listener ended")
}
};
+
+ // Wait for all the channels to have flushed/finished:
+ // We drop our sender first because the recv() call otherwise sleeps forever.
+ // When every sender has gone out of scope, the recv call will return with an error. We ignore the error.
+ drop(channel_tasks_keepalive);
+ let _ = channel_tasks_waiter.recv().await;
+
+ connection.close(VarInt::from_u32(0), "closed".as_bytes());
}
// Receive messages from the async client tasks and update the sync client.
diff --git a/src/shared/channel.rs b/src/shared/channel.rs
index 96d8afc..72dca41 100644
--- a/src/shared/channel.rs
+++ b/src/shared/channel.rs
@@ -2,7 +2,7 @@ use super::QuinnetError;
use bevy::prelude::{error, trace};
use bytes::Bytes;
use futures::sink::SinkExt;
-use quinn::{SendStream, VarInt};
+use quinn::SendStream;
use std::fmt::Debug;
use tokio::sync::{
broadcast,
@@ -69,6 +69,7 @@ impl Channel {
pub(crate) async fn ordered_reliable_channel_task<T: Debug>(
connection: quinn::Connection,
+ _: mpsc::Sender<()>,
to_sync_client: mpsc::Sender<T>,
on_lost_connection: fn() -> T,
mut close_receiver: broadcast::Receiver<()>,
@@ -97,7 +98,7 @@ pub(crate) async fn ordered_reliable_channel_task<T: Debug>(
} => {
trace!("Ordered Reliable Channel task ended")
}
- }
+ };
while let Ok(msg_bytes) = to_server_receiver.try_recv() {
if let Err(err) = frame_sender.send(msg_bytes).await {
error!("Error while sending, {}", err);
@@ -109,12 +110,11 @@ pub(crate) async fn ordered_reliable_channel_task<T: Debug>(
if let Err(err) = frame_sender.into_inner().finish().await {
error!("Failed to shutdown stream gracefully: {}", err);
}
- todo!("Do not close here, wait for all channels to be flushed");
- connection.close(VarInt::from_u32(0), "closed".as_bytes());
}
pub(crate) async fn unordered_reliable_channel_task<T: Debug>(
connection: quinn::Connection,
+ _: mpsc::Sender<()>,
to_sync_client: mpsc::Sender<T>,
on_lost_connection: fn() -> T,
mut close_receiver: broadcast::Receiver<()>,
@@ -138,12 +138,19 @@ pub(crate) async fn unordered_reliable_channel_task<T: Debug>(
.await
.expect("Failed to signal connection lost to sync client");
}
+ todo!("finish the stream")
}
} => {
trace!("Unordered Reliable Channel task ended")
}
+ };
+ while let Ok(msg_bytes) = to_server_receiver.try_recv() {
+ let mut frame_sender = new_uni_frame_sender(&connection).await;
+ if let Err(err) = frame_sender.send(msg_bytes).await {
+ error!("Error while sending, {}", err);
+ }
+ todo!("finish the stream")
}
- todo!("Flush and signal finished")
}
async fn new_uni_frame_sender(