diff options
author | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2022-11-18 15:49:46 +0100 |
---|---|---|
committer | Henauxg <19689618+Henauxg@users.noreply.github.com> | 2022-11-18 15:49:46 +0100 |
commit | f3c9c4bb7e3a32e69df9b878145fbdd08a3655d4 (patch) | |
tree | d20d7163b83b65143b1c5721bbe5ffcf566971ef | |
parent | a98498641df82547c9c05b35bab33377a356fe64 (diff) |
[server] Rework : add Endpoint
-rw-r--r-- | src/lib.rs | 15 | ||||
-rw-r--r-- | src/server.rs | 430 | ||||
-rw-r--r-- | src/server/certificate.rs | 114 |
3 files changed, 278 insertions, 281 deletions
@@ -1,7 +1,3 @@ -// pub const DEFAULT_MESSAGE_QUEUE_SIZE: usize = 150; -// pub const DEFAULT_KILL_MESSAGE_QUEUE_SIZE: usize = 10; -// pub const DEFAULT_KEEP_ALIVE_INTERVAL_S: u64 = 4; - pub mod client; pub mod server; pub mod shared; @@ -127,6 +123,7 @@ mod tests { let (client_message, client_id) = server_app .world .resource_mut::<Server>() + .endpoint_mut() .receive_message::<SharedMessage>() .expect("Failed to receive client message") .expect("There should be a client message"); @@ -144,6 +141,7 @@ mod tests { { let server = server_app.world.resource::<Server>(); server + .endpoint() .broadcast_message(sent_server_message.clone()) .unwrap(); } @@ -200,7 +198,7 @@ mod tests { { let mut server = server_app.world.resource_mut::<Server>(); server - .start( + .open_endpoint( ServerConfigurationData::new( SERVER_HOST.to_string(), SERVER_PORT, @@ -308,12 +306,11 @@ mod tests { } // Server reboots, and generates a new self-signed certificate - // TODO Close server endpoint here - { let mut server = server_app.world.resource_mut::<Server>(); + server.close_endpoint(); server - .start( + .open_endpoint( ServerConfigurationData::new( SERVER_HOST.to_string(), SERVER_PORT, @@ -451,7 +448,7 @@ mod tests { fn start_listening(mut server: ResMut<Server>) { server - .start( + .open_endpoint( ServerConfigurationData::new( SERVER_HOST.to_string(), SERVER_PORT, diff --git a/src/server.rs b/src/server.rs index 029ab7b..f8ec3f6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,9 +4,10 @@ use bevy::prelude::*; use bytes::Bytes; use futures::sink::SinkExt; use futures_util::StreamExt; -use quinn::{Endpoint, ServerConfig}; +use quinn::{Endpoint as QuinnEndpoint, ServerConfig}; use serde::Deserialize; use tokio::{ + runtime, sync::{ broadcast::{self}, mpsc::{self, error::TryRecvError}, @@ -23,7 +24,7 @@ use crate::{ }, }; -use self::certificate::{CertificateRetrievalMode, CertificateRetrievedEvent}; +use self::certificate::{CertificateRetrievalMode, CertificateRetrievedEvent, ServerCertificate}; pub mod certificate; @@ -85,13 +86,6 @@ pub struct ClientPayload { msg: Bytes, } -/// Current state of the client driver -#[derive(Debug, PartialEq, Eq)] -enum ServerState { - Idle, - Listening, -} - #[derive(Debug)] pub(crate) enum InternalAsyncMessage { ClientConnected(ClientConnection), @@ -101,10 +95,6 @@ pub(crate) enum InternalAsyncMessage { #[derive(Debug, Clone)] pub(crate) enum InternalSyncMessage { - StartListening { - config: ServerConfigurationData, - cert_mode: CertificateRetrievalMode, - }, ClientConnectedAck(ClientId), } @@ -115,43 +105,15 @@ pub(crate) struct ClientConnection { close_sender: broadcast::Sender<()>, } -#[derive(Resource)] -pub struct Server { +pub struct Endpoint { clients: HashMap<ClientId, ClientConnection>, receiver: mpsc::Receiver<ClientPayload>, - state: ServerState, pub(crate) internal_receiver: mpsc::Receiver<InternalAsyncMessage>, pub(crate) internal_sender: broadcast::Sender<InternalSyncMessage>, } -impl Server { - /// Run the server with the given [ServerConfigurationData] and [CertificateRetrievalMode] - pub fn start( - &mut self, - config: ServerConfigurationData, - cert_mode: CertificateRetrievalMode, - ) -> Result<(), QuinnetError> { - match self - .internal_sender - .send(InternalSyncMessage::StartListening { config, cert_mode }) - { - Ok(_) => { - self.state = ServerState::Listening; - Ok(()) - } - Err(_) => Err(QuinnetError::FullQueue), - } - } - - /// Returns true if the server is currently listening for messages and connections. - pub fn is_listening(&self) -> bool { - match self.state { - ServerState::Idle => false, - ServerState::Listening => true, - } - } - +impl Endpoint { pub fn disconnect_client(&mut self, client_id: ClientId) -> Result<(), QuinnetError> { match self.clients.remove(&client_id) { Some(client_connection) => match client_connection.close_sender.send(()) { @@ -186,7 +148,7 @@ impl Server { } pub fn send_group_message<'a, I: Iterator<Item = &'a ClientId>, T: serde::Serialize>( - &mut self, + &self, client_ids: I, message: T, ) -> Result<(), QuinnetError> { @@ -227,7 +189,7 @@ impl Server { } pub fn send_payload<T: Into<Bytes>>( - &mut self, + &self, client_id: ClientId, payload: T, ) -> Result<(), QuinnetError> { @@ -255,43 +217,103 @@ impl Server { } } -async fn connections_listening_task( - config: ServerConfigurationData, - cert_mode: CertificateRetrievalMode, +#[derive(Resource)] +pub struct Server { + runtime: runtime::Handle, + endpoint: Option<Endpoint>, +} + +impl Server { + pub fn endpoint(&self) -> &Endpoint { + self.endpoint.as_ref().unwrap() + } + + pub fn endpoint_mut(&mut self) -> &mut Endpoint { + self.endpoint.as_mut().unwrap() + } + + pub fn get_endpoint(&self) -> Option<&Endpoint> { + self.endpoint.as_ref() + } + + pub fn get_endpoint_mut(&mut self) -> Option<&mut Endpoint> { + self.endpoint.as_mut() + } + + /// Run the server with the given [ServerConfigurationData] and [CertificateRetrievalMode] + pub fn open_endpoint( + &mut self, + config: ServerConfigurationData, + cert_mode: CertificateRetrievalMode, + ) -> Result<ServerCertificate, QuinnetError> { + let server_adr_str = format!("{}:{}", config.local_bind_host, config.port); + let server_addr = server_adr_str.parse::<SocketAddr>()?; + + // Endpoint configuration + let server_cert = retrieve_certificate(&config.host, cert_mode)?; + let mut server_config = ServerConfig::with_single_cert( + server_cert.cert_chain.clone(), + server_cert.priv_key.clone(), + )?; + Arc::get_mut(&mut server_config.transport) + .ok_or(QuinnetError::LockAcquisitionFailure)? + .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S))); + + let (from_clients_sender, from_clients_receiver) = + mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE); + let (to_sync_server, from_async_server) = + mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + let (to_async_server, from_sync_server) = + broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); + + info!("Starting endpoint on: {} ...", server_adr_str); + + self.runtime.spawn(async move { + endpoint_task( + server_config, + server_addr, + to_sync_server.clone(), + from_sync_server, + from_clients_sender.clone(), + ) + .await; + }); + + self.endpoint = Some(Endpoint { + clients: HashMap::new(), + receiver: from_clients_receiver, + internal_receiver: from_async_server, + internal_sender: to_async_server.clone(), + }); + + Ok(server_cert) + } + + pub fn close_endpoint(&mut self) { + self.endpoint = None; + } + + /// Returns true if the server is currently listening for messages and connections. + pub fn is_listening(&self) -> bool { + match &self.endpoint { + Some(_) => true, + None => false, + } + } +} + +async fn endpoint_task( + endpoint_config: ServerConfig, + endpoint_adr: SocketAddr, to_sync_server: mpsc::Sender<InternalAsyncMessage>, mut from_sync_server: broadcast::Receiver<InternalSyncMessage>, from_clients_sender: mpsc::Sender<ClientPayload>, ) { - // TODO Fix: handle unwraps - let server_adr_str = format!("{}:{}", config.local_bind_host, config.port); - info!("Starting server on: {} ...", server_adr_str); - - let server_addr: SocketAddr = server_adr_str - .parse() - .expect("Failed to parse server address"); - - // Endpoint configuration - let (cert_chain, priv_key, cert_fingerprint) = - retrieve_certificate(&config.host, cert_mode).expect("Failed to retrieve certificate"); - to_sync_server - .send(InternalAsyncMessage::CertificateRetrieved( - CertificateRetrievedEvent { - fingerprint: cert_fingerprint, - }, - )) - .await - .expect("Failed to signal certificate retrieval to sync server"); - - let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key).unwrap(); - Arc::get_mut(&mut server_config.transport) - .unwrap() - .keep_alive_interval(Some(Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_S))); - let mut client_gen_id: ClientId = 0; let mut client_id_mappings = HashMap::new(); - let endpoint = - Endpoint::server(server_config, server_addr).expect("Failed to create server endpoint"); + let endpoint = QuinnEndpoint::server(endpoint_config, endpoint_adr) + .expect("Failed to create the endpoint"); // Start iterating over incoming connections/clients. while let Some(connecting) = endpoint.accept().await { @@ -312,50 +334,28 @@ async fn connections_listening_task( ); // Create a close channel for this client - let (close_sender, mut close_receiver): ( + let (close_sender, close_receiver): ( tokio::sync::broadcast::Sender<()>, tokio::sync::broadcast::Receiver<()>, ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); // Create an ordered reliable send channel for this client - let (to_client_sender, mut to_client_receiver) = + let (to_client_sender, to_client_receiver) = mpsc::channel::<Bytes>(DEFAULT_MESSAGE_QUEUE_SIZE); - let send_stream = connection.open_uni().await.expect( - format!( - "Failed to open unidirectional send stream for client: {}", - client_id - ) - .as_str(), - ); - - let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new()); - let to_sync_server_clone = to_sync_server.clone(); let close_sender_clone = close_sender.clone(); - let _network_broadcaster = tokio::spawn(async move { - tokio::select! { - _ = close_receiver.recv() => { - trace!("Unidirectional send stream forced to disconnected for client: {}", client_id) - } - _ = async { - while let Some(msg_bytes) = to_client_receiver.recv().await { - // TODO Perf: Batch frames for a send_all - // TODO Clean: Error handling - if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { - error!("Error while sending to client {}: {}", client_id, err); - error!("Client {} seems disconnected, closing resources", client_id); - if let Err(_) = close_sender_clone.send(()) { - error!("Failed to close all client streams & resources for client {}", client_id) - } - to_sync_server_clone.send( - InternalAsyncMessage::ClientLostConnection(client_id)) - .await - .expect("Failed to signal connection lost to sync server"); - }; - } - } => {} - } + let connection_clone = connection.clone(); + tokio::spawn(async move { + client_sender_task( + client_id, + connection_clone, + to_client_receiver, + close_receiver, + close_sender_clone, + to_sync_server_clone, + ) + .await }); // Signal the sync server of this new connection @@ -367,97 +367,115 @@ async fn connections_listening_task( })) .await .expect("Failed to signal connection to sync client"); - // Wait for the sync server to acknowledge the connection before spawning reception tasks. - while let Ok(message) = from_sync_server.recv().await { - match message { - InternalSyncMessage::ClientConnectedAck(id) => { - if id == client_id { - break; - } - } - _ => {} + while let Ok(InternalSyncMessage::ClientConnectedAck(id)) = from_sync_server.recv().await { + if id == client_id { + break; } } - // Spawn a task to listen for stream opened from this client + // Spawn a task to listen for streams opened by this client let from_client_sender_clone = from_clients_sender.clone(); - let mut uni_receivers: JoinSet<()> = JoinSet::new(); - let mut close_receiver = close_sender.subscribe(); let _client_receiver = tokio::spawn(async move { - tokio::select! { - _ = close_receiver.recv() => { - trace!("New Stream listener forced to disconnected for client: {}", client_id) - } - _ = async { - // For each new stream opened by the client - while let Ok(recv) = connection.accept_uni().await { - let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - - // Spawn a task to receive data on this stream. - let from_client_sender = from_client_sender_clone.clone(); - uni_receivers.spawn(async move { - while let Some(Ok(msg_bytes)) = frame_recv.next().await { - from_client_sender - .send(ClientPayload { - client_id: client_id, - msg: msg_bytes.into(), - }) - .await - .unwrap();// TODO Fix: error event - } - trace!("Unidirectional stream receiver ended for client: {}", client_id) - }); - } - } => { - trace!("New Stream listener ended for client: {}", client_id) - } - } - uni_receivers.shutdown().await; - trace!( - "All unidirectional stream receivers cleaned for client: {}", - client_id + client_receiver_task( + client_id, + connection, + close_sender.subscribe(), + from_client_sender_clone, ) + .await }); } } -fn start_async_server(mut commands: Commands, runtime: Res<AsyncRuntime>) { - // TODO Clean: Configure size - let (from_clients_sender, from_clients_receiver) = - mpsc::channel::<ClientPayload>(DEFAULT_MESSAGE_QUEUE_SIZE); - - let (to_sync_server, from_async_server) = - mpsc::channel::<InternalAsyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - let (to_async_server, mut from_sync_server) = - broadcast::channel::<InternalSyncMessage>(DEFAULT_INTERNAL_MESSAGE_CHANNEL_SIZE); - - commands.insert_resource(Server { - clients: HashMap::new(), - receiver: from_clients_receiver, - state: ServerState::Idle, - internal_receiver: from_async_server, - internal_sender: to_async_server.clone(), - }); +async fn client_sender_task( + client_id: ClientId, + connection: quinn::Connection, + mut to_client_receiver: tokio::sync::mpsc::Receiver<Bytes>, + mut close_receiver: tokio::sync::broadcast::Receiver<()>, + close_sender: tokio::sync::broadcast::Sender<()>, + to_sync_server: mpsc::Sender<InternalAsyncMessage>, +) { + let send_stream = connection.open_uni().await.expect( + format!( + "Failed to open unidirectional send stream for client: {}", + client_id + ) + .as_str(), + ); + + let mut framed_send_stream = FramedWrite::new(send_stream, LengthDelimitedCodec::new()); + + tokio::select! { + _ = close_receiver.recv() => { + trace!("Unidirectional send stream forced to disconnected for client: {}", client_id) + } + _ = async { + while let Some(msg_bytes) = to_client_receiver.recv().await { + // TODO Perf: Batch frames for a send_all + // TODO Clean: Error handling + if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { + error!("Error while sending to client {}: {}", client_id, err); + error!("Client {} seems disconnected, closing resources", client_id); + if let Err(_) = close_sender.send(()) { + error!("Failed to close all client streams & resources for client {}", client_id) + } + to_sync_server.send( + InternalAsyncMessage::ClientLostConnection(client_id)) + .await + .expect("Failed to signal connection lost to sync server"); + }; + } + } => {} + } +} - // Create async server task - runtime.spawn(async move { - // Wait for the sync server to to signal us to start listening - while let Ok(message) = from_sync_server.recv().await { - match message { - InternalSyncMessage::StartListening { config, cert_mode } => { - connections_listening_task( - config, - cert_mode, - to_sync_server.clone(), - to_async_server.subscribe(), - from_clients_sender.clone(), - ) - .await; - } - _ => {} +async fn client_receiver_task( + client_id: ClientId, + connection: quinn::Connection, + mut close_receiver: tokio::sync::broadcast::Receiver<()>, + from_clients_sender: mpsc::Sender<ClientPayload>, +) { + let mut uni_receivers: JoinSet<()> = JoinSet::new(); + tokio::select! { + _ = close_receiver.recv() => { + trace!("New Stream listener forced to disconnected for client: {}", client_id) + } + _ = async { + // For each new stream opened by the client + while let Ok(recv) = connection.accept_uni().await { + let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); + + // Spawn a task to receive data on this stream. + let from_client_sender = from_clients_sender.clone(); + uni_receivers.spawn(async move { + while let Some(Ok(msg_bytes)) = frame_recv.next().await { + from_client_sender + .send(ClientPayload { + client_id: client_id, + msg: msg_bytes.into(), + }) + .await + .unwrap();// TODO Fix: error event + } + trace!("Unidirectional stream receiver ended for client: {}", client_id) + }); } + } => { + trace!("New Stream listener ended for client: {}", client_id) } + } + uni_receivers.shutdown().await; + trace!( + "All unidirectional stream receivers cleaned for client: {}", + client_id + ) +} + +fn create_server(mut commands: Commands, runtime: Res<AsyncRuntime>) { + commands.insert_resource(Server { + endpoint: None, + runtime: runtime.handle().clone(), }); } @@ -468,23 +486,25 @@ fn update_sync_server( mut connection_lost_events: EventWriter<ConnectionLostEvent>, mut certificate_retrieved_events: EventWriter<CertificateRetrievedEvent>, ) { - while let Ok(message) = server.internal_receiver.try_recv() { - match message { - InternalAsyncMessage::ClientConnected(connection) => { - let id = connection.client_id; - server.clients.insert(id, connection); - server - .internal_sender - .send(InternalSyncMessage::ClientConnectedAck(id)) - .unwrap(); - connection_events.send(ConnectionEvent { id: id }); - } - InternalAsyncMessage::ClientLostConnection(client_id) => { - server.clients.remove(&client_id); - connection_lost_events.send(ConnectionLostEvent { id: client_id }); - } - InternalAsyncMessage::CertificateRetrieved(event) => { - certificate_retrieved_events.send(event) + if let Some(endpoint) = server.get_endpoint_mut() { + while let Ok(message) = endpoint.internal_receiver.try_recv() { + match message { + InternalAsyncMessage::ClientConnected(connection) => { + let id = connection.client_id; + endpoint.clients.insert(id, connection); + endpoint + .internal_sender + .send(InternalSyncMessage::ClientConnectedAck(id)) + .unwrap(); + connection_events.send(ConnectionEvent { id: id }); + } + InternalAsyncMessage::ClientLostConnection(client_id) => { + endpoint.clients.remove(&client_id); + connection_lost_events.send(ConnectionLostEvent { id: client_id }); + } + InternalAsyncMessage::CertificateRetrieved(event) => { + certificate_retrieved_events.send(event) + } } } } @@ -503,7 +523,7 @@ impl Plugin for QuinnetServerPlugin { app.add_event::<ConnectionEvent>() .add_event::<ConnectionLostEvent>() .add_event::<CertificateRetrievedEvent>() - .add_startup_system_to_stage(StartupStage::PreStartup, start_async_server) + .add_startup_system_to_stage(StartupStage::PreStartup, create_server) .add_system_to_stage(CoreStage::PreUpdate, update_sync_server); if app.world.get_resource_mut::<AsyncRuntime>().is_none() { diff --git a/src/server/certificate.rs b/src/server/certificate.rs index 2dd207a..3cb9820 100644 --- a/src/server/certificate.rs +++ b/src/server/certificate.rs @@ -1,5 +1,4 @@ use std::{ - error::Error, fs::{self, File}, io::BufReader, path::Path, @@ -7,7 +6,7 @@ use std::{ use bevy::prelude::{trace, warn}; -use crate::shared::CertificateFingerprint; +use crate::shared::{CertificateFingerprint, QuinnetError}; /// Event raised when a certificate is retrieved on the server #[derive(Debug, Clone)] @@ -37,19 +36,18 @@ pub enum CertificateRetrievalMode { }, } +pub struct ServerCertificate { + pub cert_chain: Vec<rustls::Certificate>, + pub priv_key: rustls::PrivateKey, + pub fingerprint: CertificateFingerprint, +} + fn read_certs_from_files( cert_file: &String, key_file: &String, -) -> Result< - ( - Vec<rustls::Certificate>, - rustls::PrivateKey, - CertificateFingerprint, - ), - Box<dyn Error>, -> { +) -> Result<ServerCertificate, QuinnetError> { let mut cert_chain_reader = BufReader::new(File::open(cert_file)?); - let certs: Vec<rustls::Certificate> = rustls_pemfile::certs(&mut cert_chain_reader)? + let cert_chain: Vec<rustls::Certificate> = rustls_pemfile::certs(&mut cert_chain_reader)? .into_iter() .map(rustls::Certificate) .collect(); @@ -58,19 +56,23 @@ fn read_certs_from_files( let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader)?; assert_eq!(keys.len(), 1); - let key = rustls::PrivateKey(keys.remove(0)); + let priv_key = rustls::PrivateKey(keys.remove(0)); - assert!(certs.len() >= 1); - let fingerprint = CertificateFingerprint::from(&certs[0]); + assert!(cert_chain.len() >= 1); + let fingerprint = CertificateFingerprint::from(&cert_chain[0]); - Ok((certs, key, fingerprint)) + Ok(ServerCertificate { + cert_chain, + priv_key, + fingerprint, + }) } fn write_certs_to_files( cert: &rcgen::Certificate, cert_file: &String, key_file: &String, -) -> Result<(), Box<dyn Error>> { +) -> Result<(), QuinnetError> { let pem_cert = cert.serialize_pem()?; let pem_key = cert.serialize_private_key_pem(); @@ -82,81 +84,59 @@ fn write_certs_to_files( fn generate_self_signed_certificate( server_host: &String, -) -> Result< - ( - Vec<rustls::Certificate>, - rustls::PrivateKey, - rcgen::Certificate, - CertificateFingerprint, - ), - Box<dyn Error>, -> { - let cert = rcgen::generate_simple_self_signed(vec![server_host.into()]).unwrap(); - let cert_der = cert.serialize_der().unwrap(); +) -> Result<(ServerCertificate, rcgen::Certificate), QuinnetError> { + let cert = rcgen::generate_simple_self_signed(vec![server_host.into()])?; + let cert_der = cert.serialize_der()?; let priv_key = rustls::PrivateKey(cert.serialize_private_key_der()); let rustls_cert = rustls::Certificate(cert_der.clone()); let fingerprint = CertificateFingerprint::from(&rustls_cert); let cert_chain = vec![rustls_cert]; - Ok((cert_chain, priv_key, cert, fingerprint)) + Ok(( + ServerCertificate { + cert_chain, + priv_key, + fingerprint, + }, + cert, + )) } pub(crate) fn retrieve_certificate( server_host: &String, cert_mode: CertificateRetrievalMode, -) -> Result< - ( - Vec<rustls::Certificate>, - rustls::PrivateKey, - CertificateFingerprint, - ), - Box<dyn Error>, -> { +) -> Result<ServerCertificate, QuinnetError> { match cert_mode { CertificateRetrievalMode::GenerateSelfSigned => { - trace!("Generating a new self-signed certificate"); - match generate_self_signed_certificate(server_host) { - Ok((cert_chain, priv_key, _rcgen_cert, fingerprint)) => { - Ok((cert_chain, priv_key, fingerprint)) - } - Err(e) => Err(e), - } + let (server_cert, _rcgen_cert) = generate_self_signed_certificate(server_host)?; + trace!("Generatied a new self-signed certificate"); + Ok(server_cert) } CertificateRetrievalMode::LoadFromFile { cert_file, key_file, - } => match read_certs_from_files(&cert_file, &key_file) { - Ok((cert_chain, priv_key, fingerprint)) => { - trace!("Successfuly loaded cert and key from files"); - Ok((cert_chain, priv_key, fingerprint)) - } - Err(e) => Err(e), - }, + } => { + let server_cert = read_certs_from_files(&cert_file, &key_file)?; + trace!("Successfuly loaded cert and key from files"); + Ok(server_cert) + } CertificateRetrievalMode::LoadFromFileOrGenerateSelfSigned { save_on_disk, cert_file, key_file, } => { if Path::new(&cert_file).exists() && Path::new(&key_file).exists() { - match read_certs_from_files(&cert_file, &key_file) { - Ok((cert_chain, priv_key, fingerprint)) => { - trace!("Successfuly loaded cert and key from files"); - Ok((cert_chain, priv_key, fingerprint)) - } - Err(e) => Err(e), - } + let server_cert = read_certs_from_files(&cert_file, &key_file)?; + trace!("Successfuly loaded cert and key from files"); + Ok(server_cert) } else { - warn!("{} and/or {} do not exist, could not load existing certificate. Generating a self-signed one.", cert_file, key_file); - match generate_self_signed_certificate(server_host) { - Ok((cert_chain, priv_key, rcgen_cert, fingerprint)) => { - if save_on_disk { - write_certs_to_files(&rcgen_cert, &cert_file, &key_file)?; - trace!("Successfuly saved cert and key to files"); - } - Ok((cert_chain, priv_key, fingerprint)) - } - Err(e) => Err(e), + warn!("{} and/or {} do not exist, could not load existing certificate. Generating a new self-signed certificate.", cert_file, key_file); + let (server_cert, rcgen_cert) = generate_self_signed_certificate(server_host)?; + if save_on_disk { + write_certs_to_files(&rcgen_cert, &cert_file, &key_file)?; + trace!("Successfuly saved cert and key to files"); } + Ok(server_cert) } } } |