aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenauxg <19689618+Henauxg@users.noreply.github.com>2022-11-18 15:49:46 +0100
committerHenauxg <19689618+Henauxg@users.noreply.github.com>2022-11-18 15:49:46 +0100
commitf3c9c4bb7e3a32e69df9b878145fbdd08a3655d4 (patch)
treed20d7163b83b65143b1c5721bbe5ffcf566971ef
parenta98498641df82547c9c05b35bab33377a356fe64 (diff)
[server] Rework : add Endpoint
-rw-r--r--src/lib.rs15
-rw-r--r--src/server.rs430
-rw-r--r--src/server/certificate.rs114
3 files changed, 278 insertions, 281 deletions
diff --git a/src/lib.rs b/src/lib.rs
index e25c68a..813a344 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,7 +1,3 @@
-// pub const DEFAULT_MESSAGE_QUEUE_SIZE: usize = 150;
-// pub const DEFAULT_KILL_MESSAGE_QUEUE_SIZE: usize = 10;
-// pub const DEFAULT_KEEP_ALIVE_INTERVAL_S: u64 = 4;
-
pub mod client;
pub mod server;
pub mod shared;
@@ -127,6 +123,7 @@ mod tests {
let (client_message, client_id) = server_app
.world
.resource_mut::<Server>()
+ .endpoint_mut()
.receive_message::<SharedMessage>()
.expect("Failed to receive client message")
.expect("There should be a client message");
@@ -144,6 +141,7 @@ mod tests {
{
let server = server_app.world.resource::<Server>();
server
+ .endpoint()
.broadcast_message(sent_server_message.clone())
.unwrap();
}
@@ -200,7 +198,7 @@ mod tests {
{
let mut server = server_app.world.resource_mut::<Server>();
server
- .start(
+ .open_endpoint(
ServerConfigurationData::new(
SERVER_HOST.to_string(),
SERVER_PORT,
@@ -308,12 +306,11 @@ mod tests {
}
// Server reboots, and generates a new self-signed certificate
- // TODO Close server endpoint here
-
{
let mut server = server_app.world.resource_mut::<Server>();
+ server.close_endpoint();
server
- .start(
+ .open_endpoint(
ServerConfigurationData::new(
SERVER_HOST.to_string(),
SERVER_PORT,
@@ -451,7 +448,7 @@ mod tests {
fn start_listening(mut server: ResMut<Server>) {
server
- .start(
+ .open_endpoint(
ServerConfigurationData::new(
SERVER_HOST.to_string(),
SERVER_PORT,
diff --git a/src/server.rs b/src/server.rs
index 029ab7b..f8ec3f6 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -4,9 +4,10 @@ use bevy::prelude::*;
use bytes::Bytes;
use futures::sink::SinkExt;
use futures_util::StreamExt;
-use quinn::{Endpoint, ServerConfig};
+use quinn::{Endpoint as QuinnEndpoint, ServerConfig};
use serde::Deserialize;
use tokio::{
+ runtime,
sync::{
broadcast::{self},
mpsc::{self, error::TryRecvError},
@@ -23,7 +24,7 @@ use crate::{
},
};
-use self::certificate::{CertificateRetrievalMode, CertificateRetrievedEvent};
+use self::certificate::{CertificateRetrievalMode, CertificateRetrievedEvent, ServerCertificate};
pub mod certificate;
@@ -85,13 +86,6 @@ pub struct ClientPayload {
msg: Bytes,
}
-/// Current state of the client driver
-#[derive(Debug, PartialEq, Eq)]
-enum ServerState {
- Idle,
- Listening,
-}
-
#[derive(Debug)]
pub(crate) enum InternalAsyncMessage {
ClientConnected(ClientConnection),
@@ -101,10 +95,6 @@ pub(crate) enum InternalAsyncMessage {
#[derive(Debug, Clone)]
pub(crate) enum InternalSyncMessage {
- StartListening {
- config: ServerConfigurationData,
- cert_mode: CertificateRetrievalMode,
- },
ClientConnectedAck(ClientId),
}
@@ -115,43 +105,15 @@ pub(crate) struct ClientConnection {
close_sender: broadcast::Sender<()>,
}
-#[derive(Resource)]
-pub struct Server {
+pub struct Endpoint {
clients: HashMap<ClientId, ClientConnection>,
receiver: mpsc::Receiver<ClientPayload>,
- state: ServerState,
pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>,
pub(crate) internal_sender: broadcast::Sender<InternalSyncMessage>,
}
-impl Server {
- /// Run the server with the given [ServerConfigurationData] and [CertificateRetrievalMode]
- pub fn start(
- &mut self,
- config: ServerConfigurationData,
- cert_mode: CertificateRetrievalMode,
- ) -> Result<(), QuinnetError> {
- match self
- .internal_sender
- .send(InternalSyncMessage::StartListening { config, cert_mode })
- {
- Ok(_) => {
- self.state = ServerState::Listening;
- Ok(())
- }
- Err(_) => Err(QuinnetError::FullQueue),
- }
- }
-
- /// Returns true if the server is currently listening for messages and connections.
- pub fn is_listening(&self) -> bool {
- match self.state {
- ServerState::Idle => false,
- ServerState::Listening => true,
- }
- }
-
+impl Endpoint {
pub fn disconnect_client(&mut self, client_id: ClientId) -> Result<(), QuinnetError> {
match self.clients.remove(&client_id) {
Some(client_connection) => match client_connection.close_sender.send(()) {
@@ -186,7 +148,7 @@ impl Server {
}
pub fn send_group_message<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>(
- &mut self,
+ &self,
client_ids: I,
message: T,
) -> Result<(), QuinnetError> {
@@ -227,7 +189,7 @@ impl Server {
}
pub fn send_payload<T: Into<Bytes>>(
- &mut self,
+ &self,
client_id: ClientId,
payload: T,
) -> Result<(), QuinnetError> {
@@ -255,43 +217,103 @@ impl Server {
}
}
-async fn connections_listening_task(
- config: ServerConfigurationData,
- cert_mode: CertificateRetrievalMode,
+#[derive(Resource)]
+pub struct Server {
+ runtime: runtime::Handle,
+ endpoint: Option<Endpoint>,
+}
+
+impl Server {
+ pub fn endpoint(&self) -> &Endpoint {
+ self.endpoint.as_ref().unwrap()
+ }
+
+ pub fn endpoint_mut(&mut self) -> &mut Endpoint {
+ self.endpoint.as_mut().unwrap()
+ }
+
+ pub fn get_endpoint(&self) -> Option<&Endpoint> {
+ self.endpoint.as_ref()
+ }
+
+ pub fn get_endpoint_mut(&mut self) -> Option<&mut Endpoint> {
+ self.endpoint.as_mut()
+ }
+
+ /// Run the server with the given [ServerConfigurationData] and [CertificateRetrievalMode]
+ pub fn open_endpoint(
+ &mut self,
+ config: ServerConfigurationData,
+ cert_mode: CertificateRetrievalMode,
+ ) -> Result<ServerCertificate, QuinnetError> {
+ let server_adr_str = format!("{}:{}", config.local_bind_host, config.port);
+ let server_addr = server_adr_str.parse::<SocketAddr>()?;
+
+ // Endpoint configuration
+ let server_cert = retrieve_certificate(&config.host, cert_mode)?;
+ let mut server_config = ServerConfig::with_single_cert(
+ server_cert.cert_chain.clone(),
+ server_cert.priv_key.clone(),
+ )?;
+ Arc::get_mut(&mut server_config.transport)
+ .ok_or(QuinnetError::LockAcquisitionFailure)?
+ .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
+
+ let (from_clients_sender, from_clients_receiver) =
+ mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
+ let (to_sync_server, from_async_server) =
+ mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
+ let (to_async_server, from_sync_server) =
+ broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
+
+ info!("Starting endpoint on: {} ...", server_adr_str);
+
+ self.runtime.spawn(async move {
+ endpoint_task(
+ server_config,
+ server_addr,
+ to_sync_server.clone(),
+ from_sync_server,
+ from_clients_sender.clone(),
+ )
+ .await;
+ });
+
+ self.endpoint = Some(Endpoint {
+ clients: HashMap::new(),
+ receiver: from_clients_receiver,
+ internal_receiver: from_async_server,
+ internal_sender: to_async_server.clone(),
+ });
+
+ Ok(server_cert)
+ }
+
+ pub fn close_endpoint(&mut self) {
+ self.endpoint = None;
+ }
+
+ /// Returns true if the server is currently listening for messages and connections.
+ pub fn is_listening(&self) -> bool {
+ match &self.endpoint {
+ Some(_) => true,
+ None => false,
+ }
+ }
+}
+
+async fn endpoint_task(
+ endpoint_config: ServerConfig,
+ endpoint_adr: SocketAddr,
to_sync_server: mpsc::Sender<InternalAsyncMessage>,
mut from_sync_server: broadcast::Receiver<InternalSyncMessage>,
from_clients_sender: mpsc::Sender<ClientPayload>,
) {
- // TODO Fix: handle unwraps
- let server_adr_str = format!("{}:{}", config.local_bind_host, config.port);
- info!("Starting server on: {} ...", server_adr_str);
-
- let server_addr: SocketAddr = server_adr_str
- .parse()
- .expect("Failed to parse server address");
-
- // Endpoint configuration
- let (cert_chain, priv_key, cert_fingerprint) =
- retrieve_certificate(&config.host, cert_mode).expect("Failed to retrieve certificate");
- to_sync_server
- .send(InternalAsyncMessage::CertificateRetrieved(
- CertificateRetrievedEvent {
- fingerprint: cert_fingerprint,
- },
- ))
- .await
- .expect("Failed to signal certificate retrieval to sync server");
-
- let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key).unwrap();
- Arc::get_mut(&mut server_config.transport)
- .unwrap()
- .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S)));
-
let mut client_gen_id: ClientId = 0;
let mut client_id_mappings = HashMap::new();
- let endpoint =
- Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint");
+ let endpoint = QuinnEndpoint::server(endpoint_config, endpoint_adr)
+ .expect("Failed to create the endpoint");
// Start iterating over incoming connections/clients.
while let Some(connecting) = endpoint.accept().await {
@@ -312,50 +334,28 @@ async fn connections_listening_task(
);
// Create a close channel for this client
- let (close_sender, mut close_receiver): (
+ let (close_sender, close_receiver): (
tokio::sync::broadcast::Sender<()>,
tokio::sync::broadcast::Receiver<()>,
) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
// Create an ordered reliable send channel for this client
- let (to_client_sender, mut to_client_receiver) =
+ let (to_client_sender, to_client_receiver) =
mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE);
- let send_stream = connection.open_uni().await.expect(
- format!(
- "Failed to open unidirectional send stream for client: {}",
- client_id
- )
- .as_str(),
- );
-
- let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
-
let to_sync_server_clone = to_sync_server.clone();
let close_sender_clone = close_sender.clone();
- let _network_broadcaster = tokio::spawn(async move {
- tokio::select! {
- _ = close_receiver.recv() => {
- trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
- }
- _ = async {
- while let Some(msg_bytes) = to_client_receiver.recv().await {
- // TODO Perf: Batch frames for a send_all
- // TODO Clean: Error handling
- if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
- error!("Error while sending to client {}: {}", client_id, err);
- error!("Client {} seems disconnected, closing resources", client_id);
- if let Err(_) = close_sender_clone.send(()) {
- error!("Failed to close all client streams & resources for client {}", client_id)
- }
- to_sync_server_clone.send(
- InternalAsyncMessage::ClientLostConnection(client_id))
- .await
- .expect("Failed to signal connection lost to sync server");
- };
- }
- } => {}
- }
+ let connection_clone = connection.clone();
+ tokio::spawn(async move {
+ client_sender_task(
+ client_id,
+ connection_clone,
+ to_client_receiver,
+ close_receiver,
+ close_sender_clone,
+ to_sync_server_clone,
+ )
+ .await
});
// Signal the sync server of this new connection
@@ -367,97 +367,115 @@ async fn connections_listening_task(
}))
.await
.expect("Failed to signal connection to sync client");
-
// Wait for the sync server to acknowledge the connection before spawning reception tasks.
- while let Ok(message) = from_sync_server.recv().await {
- match message {
- InternalSyncMessage::ClientConnectedAck(id) => {
- if id == client_id {
- break;
- }
- }
- _ => {}
+ while let Ok(InternalSyncMessage::ClientConnectedAck(id)) = from_sync_server.recv().await {
+ if id == client_id {
+ break;
}
}
- // Spawn a task to listen for stream opened from this client
+ // Spawn a task to listen for streams opened by this client
let from_client_sender_clone = from_clients_sender.clone();
- let mut uni_receivers: JoinSet<()> = JoinSet::new();
- let mut close_receiver = close_sender.subscribe();
let _client_receiver = tokio::spawn(async move {
- tokio::select! {
- _ = close_receiver.recv() => {
- trace!("New Stream listener forced to disconnected for client: {}", client_id)
- }
- _ = async {
- // For each new stream opened by the client
- while let Ok(recv) = connection.accept_uni().await {
- let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
-
- // Spawn a task to receive data on this stream.
- let from_client_sender = from_client_sender_clone.clone();
- uni_receivers.spawn(async move {
- while let Some(Ok(msg_bytes)) = frame_recv.next().await {
- from_client_sender
- .send(ClientPayload {
- client_id: client_id,
- msg: msg_bytes.into(),
- })
- .await
- .unwrap();// TODO Fix: error event
- }
- trace!("Unidirectional stream receiver ended for client: {}", client_id)
- });
- }
- } => {
- trace!("New Stream listener ended for client: {}", client_id)
- }
- }
- uni_receivers.shutdown().await;
- trace!(
- "All unidirectional stream receivers cleaned for client: {}",
- client_id
+ client_receiver_task(
+ client_id,
+ connection,
+ close_sender.subscribe(),
+ from_client_sender_clone,
)
+ .await
});
}
}
-fn start_async_server(mut commands: Commands, runtime: Res<AsyncRuntime>) {
- // TODO Clean: Configure size
- let (from_clients_sender, from_clients_receiver) =
- mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE);
-
- let (to_sync_server, from_async_server) =
- mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
- let (to_async_server, mut from_sync_server) =
- broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE);
-
- commands.insert_resource(Server {
- clients: HashMap::new(),
- receiver: from_clients_receiver,
- state: ServerState::Idle,
- internal_receiver: from_async_server,
- internal_sender: to_async_server.clone(),
- });
+async fn client_sender_task(
+ client_id: ClientId,
+ connection: quinn::Connection,
+ mut to_client_receiver: tokio::sync::mpsc::Receiver<Bytes>,
+ mut close_receiver: tokio::sync::broadcast::Receiver<()>,
+ close_sender: tokio::sync::broadcast::Sender<()>,
+ to_sync_server: mpsc::Sender<InternalAsyncMessage>,
+) {
+ let send_stream = connection.open_uni().await.expect(
+ format!(
+ "Failed to open unidirectional send stream for client: {}",
+ client_id
+ )
+ .as_str(),
+ );
+
+ let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new());
+
+ tokio::select! {
+ _ = close_receiver.recv() => {
+ trace!("Unidirectional send stream forced to disconnected for client: {}", client_id)
+ }
+ _ = async {
+ while let Some(msg_bytes) = to_client_receiver.recv().await {
+ // TODO Perf: Batch frames for a send_all
+ // TODO Clean: Error handling
+ if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await {
+ error!("Error while sending to client {}: {}", client_id, err);
+ error!("Client {} seems disconnected, closing resources", client_id);
+ if let Err(_) = close_sender.send(()) {
+ error!("Failed to close all client streams & resources for client {}", client_id)
+ }
+ to_sync_server.send(
+ InternalAsyncMessage::ClientLostConnection(client_id))
+ .await
+ .expect("Failed to signal connection lost to sync server");
+ };
+ }
+ } => {}
+ }
+}
- // Create async server task
- runtime.spawn(async move {
- // Wait for the sync server to to signal us to start listening
- while let Ok(message) = from_sync_server.recv().await {
- match message {
- InternalSyncMessage::StartListening { config, cert_mode } => {
- connections_listening_task(
- config,
- cert_mode,
- to_sync_server.clone(),
- to_async_server.subscribe(),
- from_clients_sender.clone(),
- )
- .await;
- }
- _ => {}
+async fn client_receiver_task(
+ client_id: ClientId,
+ connection: quinn::Connection,
+ mut close_receiver: tokio::sync::broadcast::Receiver<()>,
+ from_clients_sender: mpsc::Sender<ClientPayload>,
+) {
+ let mut uni_receivers: JoinSet<()> = JoinSet::new();
+ tokio::select! {
+ _ = close_receiver.recv() => {
+ trace!("New Stream listener forced to disconnected for client: {}", client_id)
+ }
+ _ = async {
+ // For each new stream opened by the client
+ while let Ok(recv) = connection.accept_uni().await {
+ let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
+
+ // Spawn a task to receive data on this stream.
+ let from_client_sender = from_clients_sender.clone();
+ uni_receivers.spawn(async move {
+ while let Some(Ok(msg_bytes)) = frame_recv.next().await {
+ from_client_sender
+ .send(ClientPayload {
+ client_id: client_id,
+ msg: msg_bytes.into(),
+ })
+ .await
+ .unwrap();// TODO Fix: error event
+ }
+ trace!("Unidirectional stream receiver ended for client: {}", client_id)
+ });
}
+ } => {
+ trace!("New Stream listener ended for client: {}", client_id)
}
+ }
+ uni_receivers.shutdown().await;
+ trace!(
+ "All unidirectional stream receivers cleaned for client: {}",
+ client_id
+ )
+}
+
+fn create_server(mut commands: Commands, runtime: Res<AsyncRuntime>) {
+ commands.insert_resource(Server {
+ endpoint: None,
+ runtime: runtime.handle().clone(),
});
}
@@ -468,23 +486,25 @@ fn update_sync_server(
mut connection_lost_events: EventWriter<ConnectionLostEvent>,
mut certificate_retrieved_events: EventWriter<CertificateRetrievedEvent>,
) {
- while let Ok(message) = server.internal_receiver.try_recv() {
- match message {
- InternalAsyncMessage::ClientConnected(connection) => {
- let id = connection.client_id;
- server.clients.insert(id, connection);
- server
- .internal_sender
- .send(InternalSyncMessage::ClientConnectedAck(id))
- .unwrap();
- connection_events.send(ConnectionEvent { id: id });
- }
- InternalAsyncMessage::ClientLostConnection(client_id) => {
- server.clients.remove(&client_id);
- connection_lost_events.send(ConnectionLostEvent { id: client_id });
- }
- InternalAsyncMessage::CertificateRetrieved(event) => {
- certificate_retrieved_events.send(event)
+ if let Some(endpoint) = server.get_endpoint_mut() {
+ while let Ok(message) = endpoint.internal_receiver.try_recv() {
+ match message {
+ InternalAsyncMessage::ClientConnected(connection) => {
+ let id = connection.client_id;
+ endpoint.clients.insert(id, connection);
+ endpoint
+ .internal_sender
+ .send(InternalSyncMessage::ClientConnectedAck(id))
+ .unwrap();
+ connection_events.send(ConnectionEvent { id: id });
+ }
+ InternalAsyncMessage::ClientLostConnection(client_id) => {
+ endpoint.clients.remove(&client_id);
+ connection_lost_events.send(ConnectionLostEvent { id: client_id });
+ }
+ InternalAsyncMessage::CertificateRetrieved(event) => {
+ certificate_retrieved_events.send(event)
+ }
}
}
}
@@ -503,7 +523,7 @@ impl Plugin for QuinnetServerPlugin {
app.add_event::<ConnectionEvent>()
.add_event::<ConnectionLostEvent>()
.add_event::<CertificateRetrievedEvent>()
- .add_startup_system_to_stage(StartupStage::PreStartup, start_async_server)
+ .add_startup_system_to_stage(StartupStage::PreStartup, create_server)
.add_system_to_stage(CoreStage::PreUpdate, update_sync_server);
if app.world.get_resource_mut::<AsyncRuntime>().is_none() {
diff --git a/src/server/certificate.rs b/src/server/certificate.rs
index 2dd207a..3cb9820 100644
--- a/src/server/certificate.rs
+++ b/src/server/certificate.rs
@@ -1,5 +1,4 @@
use std::{
- error::Error,
fs::{self, File},
io::BufReader,
path::Path,
@@ -7,7 +6,7 @@ use std::{
use bevy::prelude::{trace, warn};
-use crate::shared::CertificateFingerprint;
+use crate::shared::{CertificateFingerprint, QuinnetError};
/// Event raised when a certificate is retrieved on the server
#[derive(Debug, Clone)]
@@ -37,19 +36,18 @@ pub enum CertificateRetrievalMode {
},
}
+pub struct ServerCertificate {
+ pub cert_chain: Vec<rustls::Certificate>,
+ pub priv_key: rustls::PrivateKey,
+ pub fingerprint: CertificateFingerprint,
+}
+
fn read_certs_from_files(
cert_file: &String,
key_file: &String,
-) -> Result<
- (
- Vec<rustls::Certificate>,
- rustls::PrivateKey,
- CertificateFingerprint,
- ),
- Box<dyn Error>,
-> {
+) -> Result<ServerCertificate, QuinnetError> {
let mut cert_chain_reader = BufReader::new(File::open(cert_file)?);
- let certs: Vec<rustls::Certificate> = rustls_pemfile::certs(&mut cert_chain_reader)?
+ let cert_chain: Vec<rustls::Certificate> = rustls_pemfile::certs(&mut cert_chain_reader)?
.into_iter()
.map(rustls::Certificate)
.collect();
@@ -58,19 +56,23 @@ fn read_certs_from_files(
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader)?;
assert_eq!(keys.len(), 1);
- let key = rustls::PrivateKey(keys.remove(0));
+ let priv_key = rustls::PrivateKey(keys.remove(0));
- assert!(certs.len() >= 1);
- let fingerprint = CertificateFingerprint::from(&certs[0]);
+ assert!(cert_chain.len() >= 1);
+ let fingerprint = CertificateFingerprint::from(&cert_chain[0]);
- Ok((certs, key, fingerprint))
+ Ok(ServerCertificate {
+ cert_chain,
+ priv_key,
+ fingerprint,
+ })
}
fn write_certs_to_files(
cert: &rcgen::Certificate,
cert_file: &String,
key_file: &String,
-) -> Result<(), Box<dyn Error>> {
+) -> Result<(), QuinnetError> {
let pem_cert = cert.serialize_pem()?;
let pem_key = cert.serialize_private_key_pem();
@@ -82,81 +84,59 @@ fn write_certs_to_files(
fn generate_self_signed_certificate(
server_host: &String,
-) -> Result<
- (
- Vec<rustls::Certificate>,
- rustls::PrivateKey,
- rcgen::Certificate,
- CertificateFingerprint,
- ),
- Box<dyn Error>,
-> {
- let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap();
- let cert_der = cert.serialize_der().unwrap();
+) -> Result<(ServerCertificate, rcgen::Certificate), QuinnetError> {
+ let cert = rcgen::generate_simple_self_signed(vec![server_host.into()])?;
+ let cert_der = cert.serialize_der()?;
let priv_key = rustls::PrivateKey(cert.serialize_private_key_der());
let rustls_cert = rustls::Certificate(cert_der.clone());
let fingerprint = CertificateFingerprint::from(&rustls_cert);
let cert_chain = vec![rustls_cert];
- Ok((cert_chain, priv_key, cert, fingerprint))
+ Ok((
+ ServerCertificate {
+ cert_chain,
+ priv_key,
+ fingerprint,
+ },
+ cert,
+ ))
}
pub(crate) fn retrieve_certificate(
server_host: &String,
cert_mode: CertificateRetrievalMode,
-) -> Result<
- (
- Vec<rustls::Certificate>,
- rustls::PrivateKey,
- CertificateFingerprint,
- ),
- Box<dyn Error>,
-> {
+) -> Result<ServerCertificate, QuinnetError> {
match cert_mode {
CertificateRetrievalMode::GenerateSelfSigned => {
- trace!("Generating a new self-signed certificate");
- match generate_self_signed_certificate(server_host) {
- Ok((cert_chain, priv_key, _rcgen_cert, fingerprint)) => {
- Ok((cert_chain, priv_key, fingerprint))
- }
- Err(e) => Err(e),
- }
+ let (server_cert, _rcgen_cert) = generate_self_signed_certificate(server_host)?;
+ trace!("Generatied a new self-signed certificate");
+ Ok(server_cert)
}
CertificateRetrievalMode::LoadFromFile {
cert_file,
key_file,
- } => match read_certs_from_files(&cert_file, &key_file) {
- Ok((cert_chain, priv_key, fingerprint)) => {
- trace!("Successfuly loaded cert and key from files");
- Ok((cert_chain, priv_key, fingerprint))
- }
- Err(e) => Err(e),
- },
+ } => {
+ let server_cert = read_certs_from_files(&cert_file, &key_file)?;
+ trace!("Successfuly loaded cert and key from files");
+ Ok(server_cert)
+ }
CertificateRetrievalMode::LoadFromFileOrGenerateSelfSigned {
save_on_disk,
cert_file,
key_file,
} => {
if Path::new(&cert_file).exists() && Path::new(&key_file).exists() {
- match read_certs_from_files(&cert_file, &key_file) {
- Ok((cert_chain, priv_key, fingerprint)) => {
- trace!("Successfuly loaded cert and key from files");
- Ok((cert_chain, priv_key, fingerprint))
- }
- Err(e) => Err(e),
- }
+ let server_cert = read_certs_from_files(&cert_file, &key_file)?;
+ trace!("Successfuly loaded cert and key from files");
+ Ok(server_cert)
} else {
- warn!("{} and/or {} do not exist, could not load existing certificate. Generating a self-signed one.", cert_file, key_file);
- match generate_self_signed_certificate(server_host) {
- Ok((cert_chain, priv_key, rcgen_cert, fingerprint)) => {
- if save_on_disk {
- write_certs_to_files(&rcgen_cert, &cert_file, &key_file)?;
- trace!("Successfuly saved cert and key to files");
- }
- Ok((cert_chain, priv_key, fingerprint))
- }
- Err(e) => Err(e),
+ warn!("{} and/or {} do not exist, could not load existing certificate. Generating a new self-signed certificate.", cert_file, key_file);
+ let (server_cert, rcgen_cert) = generate_self_signed_certificate(server_host)?;
+ if save_on_disk {
+ write_certs_to_files(&rcgen_cert, &cert_file, &key_file)?;
+ trace!("Successfuly saved cert and key to files");
}
+ Ok(server_cert)
}
}
}