diff options
Diffstat (limited to 'hbak_common/src/conn.rs')
-rw-r--r-- | hbak_common/src/conn.rs | 125 |
1 files changed, 113 insertions, 12 deletions
diff --git a/hbak_common/src/conn.rs b/hbak_common/src/conn.rs index 974aa63..c122007 100644 --- a/hbak_common/src/conn.rs +++ b/hbak_common/src/conn.rs @@ -1,21 +1,22 @@ -use crate::message::Target; +use crate::message::*; +use crate::system; use crate::{NetworkError, RemoteError}; +use std::io::Write; use std::net::{SocketAddr, TcpStream}; use std::time::Duration; +use chacha20poly1305::aead::generic_array::GenericArray; use chacha20poly1305::aead::stream::{DecryptorBE32, EncryptorBE32}; -use chacha20poly1305::XChaCha20Poly1305; +use chacha20poly1305::{Key, XChaCha20Poly1305}; +use subtle::ConstantTimeEq; /// TCP connect timeout. Connection attempt is aborted if remote doesn't respond. const CONNECT_TIMEOUT: Duration = Duration::from_secs(30); /// The valid states of an [`AuthConn`]. -#[derive(Debug, Default, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq)] enum AuthConnState { - /// Authentication has not started. - #[default] - Idle, /// A `Hello` message has been sent. Awaiting the `ServerAuth` response. Handshake { /// The challenge sent in the `Hello` message. @@ -26,14 +27,11 @@ enum AuthConnState { /// A `ServerAuth` message has been received and a `ClientAuth` reaction has been sent. /// Awaiting the `Encrypt` response. Proof { + /// The shared secret for mutual authentication and transport encryption. + key: Vec<u8>, /// The nonce for transport encryption. nonce: Vec<u8>, }, - /// An `Encrypt` message has been received and encryption has been configured. - /// Further plaintext reads or writes are not allowed. Transformation is imminent. - Encrypted, - /// Authentication or encryption setup has failed. The connection should be terminated. - Failed(RemoteError), } /// The valid states of an [`AuthServ`]. @@ -117,13 +115,100 @@ impl AuthConn { pub fn new(addr: &SocketAddr) -> Result<Self, NetworkError> { Ok(TcpStream::connect_timeout(addr, CONNECT_TIMEOUT)?.into()) } + + /// Performs mutual authentication and encryption of the connection + /// using the provided node name and passphrase, + /// returning a [`StreamConn`] on success. + pub fn secure_stream<P: AsRef<[u8]>>( + mut self, + node_name: String, + passphrase: P, + ) -> Result<StreamConn, NetworkError> { + // No need to check for `AuthConnState::Idle`, consuming the `AuthConn` + // guarantees that this function can never be called again. + + // Limit variables to this scope so they aren't used in the main loop by accident. + { + let AuthConnState::Handshake { challenge, nonce } = &self.state else { + unreachable!() + }; + + self.send_message(&CryptoMessage::Hello(Hello { + node_name, + challenge: challenge.to_vec(), + nonce: nonce.to_vec(), + }))?; + } + + loop { + let message = self.recv_message()?; + + match (self.state, message) { + ( + AuthConnState::Handshake { challenge, nonce }, + CryptoMessage::ServerAuth(server_auth), + ) => { + let server_auth = server_auth?; + + let key = system::derive_key(&server_auth.verifier, &passphrase)?; + let server_proof = system::hash_hmac(&key, &challenge); + + if server_auth.proof.ct_eq(&server_proof).into() { + let proof = system::hash_hmac(&key, &server_auth.challenge); + + self.state = AuthConnState::Proof { key, nonce }; + self.send_message(&CryptoMessage::ClientAuth(Ok(ClientAuth { proof })))?; + } else { + self.state = AuthConnState::Handshake { challenge, nonce }; + self.send_message(&CryptoMessage::ClientAuth(Err( + RemoteError::AccessDenied, + )))?; + + return Err(RemoteError::Unauthorized.into()); + } + } + (AuthConnState::Proof { key, nonce }, CryptoMessage::Encrypt(encrypt)) => { + encrypt?; + return Ok(StreamConn::from_conn(self.stream, key, nonce)); + } + (AuthConnState::Handshake { challenge, nonce }, _) => { + self.state = AuthConnState::Handshake { challenge, nonce }; + self.send_message(&CryptoMessage::ClientAuth(Err( + RemoteError::IllegalTransition, + )))?; + + return Err(NetworkError::IllegalTransition); + } + (AuthConnState::Proof { key, nonce }, _) => { + self.state = AuthConnState::Proof { key, nonce }; + self.send_message(&CryptoMessage::Error(RemoteError::IllegalTransition))?; + + return Err(NetworkError::IllegalTransition); + } + } + } + } + + fn send_message(&self, message: &CryptoMessage) -> Result<(), NetworkError> { + let buf = bincode::serialize(message)?; + (&self.stream).write_all(&buf)?; + + Ok(()) + } + + fn recv_message(&self) -> Result<CryptoMessage, NetworkError> { + Ok(bincode::deserialize_from(&self.stream)?) + } } impl From<TcpStream> for AuthConn { fn from(stream: TcpStream) -> Self { Self { stream, - state: AuthConnState::default(), + state: AuthConnState::Handshake { + challenge: system::random_bytes(32), + nonce: system::random_bytes(32), + }, } } } @@ -154,3 +239,19 @@ pub struct StreamConn { decryptor: DecryptorBE32<XChaCha20Poly1305>, state: StreamConnState, } + +impl StreamConn { + /// Constructs a new `StreamConn` from a [`std::net::TcpStream`], + /// encryption key and nonce. + pub(crate) fn from_conn(stream: TcpStream, key: Vec<u8>, nonce: Vec<u8>) -> Self { + let key = Key::from_slice(&key); + let nonce = GenericArray::from_slice(&nonce); + + Self { + stream, + encryptor: EncryptorBE32::new(key, nonce), + decryptor: DecryptorBE32::new(key, nonce), + state: StreamConnState::default(), + } + } +} |