diff --git a/src/comms.rs b/src/comms.rs deleted file mode 100755 index 407f14c..0000000 --- a/src/comms.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::error::Error; - -use aes_gcm::{aead::consts::U12, aes::Aes256, AesGcm}; -use base64::{engine::general_purpose, Engine}; -use rand::rngs::OsRng; -use tokio::{ - io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, - net::tcp::{ReadHalf, WriteHalf}, -}; - -use crate::crypto; - -pub async fn send( - writer: &mut BufWriter>, - cipher: Option<&mut AesGcm>, - rng: Option<&mut OsRng>, - data: &Vec, -) -> Result<(), Box> { - let enc: Vec; - - if let (Some(cipher), Some(rng)) = (cipher, rng) { - enc = crypto::aes_encrypt(data, cipher, rng)?; - } else { - enc = data.clone(); - } - - let mut encoded = general_purpose::STANDARD_NO_PAD - .encode(enc) - .as_bytes() - .to_vec(); - encoded.push(b':'); - writer.write_all(&encoded).await?; - writer.flush().await?; - - Ok(()) -} - -pub async fn recv( - reader: &mut BufReader>, - cipher: Option<&mut AesGcm>, -) -> Result, Box> { - let mut buf = Vec::new(); - let n = reader.read_until(b':', &mut buf).await?; - - if n == 0 { - return Err("Received 0 bytes from the socket".into()); - } - - buf.pop(); - buf = general_purpose::STANDARD_NO_PAD.decode(&buf)?.to_vec(); - - if let Some(cipher) = cipher { - buf = crypto::aes_decrypt(&buf, cipher)?; - } else { - buf = buf.clone(); - } - - Ok(buf) -} diff --git a/src/crypto.rs b/src/crypto.rs old mode 100755 new mode 100644 index 2ac9592..233b974 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -5,85 +5,85 @@ use aes_gcm::{ aes::Aes256, Aes256Gcm, AesGcm, KeyInit, Nonce, }; -use rand::{distributions::Alphanumeric, rngs::OsRng, Rng, RngCore}; -use tokio::{ - io::{BufReader, BufWriter}, - net::tcp::{ReadHalf, WriteHalf}, -}; +use rand::{rngs::OsRng, RngCore}; use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret}; -use crate::comms; +use crate::sockets::SocketHandler; const AES_NONCE_SIZE: usize = 12; const DH_PBK_SIZE: usize = 32; -async fn edh( - reader: &mut BufReader>, - writer: &mut BufWriter>, - go_first: bool, -) -> Result> { - let buf: Vec; - let own_sec = EphemeralSecret::new(OsRng); - let own_pbk = PublicKey::from(&own_sec); - let msg = own_pbk.as_bytes().to_vec(); +#[derive(Clone)] +pub struct Crypto { + cipher: AesGcm, + rng: OsRng, +} - if go_first { - comms::send(writer, None, None, &msg).await?; - buf = comms::recv(reader, None).await?; - } else { - buf = comms::recv(reader, None).await?; - comms::send(writer, None, None, &msg).await?; +impl Crypto { + pub async fn new( + handler: &mut SocketHandler<'_>, + go_first: bool, + ) -> Result> { + let secret = Self::ecdh(handler, go_first).await?; + let cipher = Aes256Gcm::new(secret.as_bytes().into()); + let rng = OsRng; + + Ok(Self { cipher, rng }) } - let slice: [u8; DH_PBK_SIZE] = buf[..DH_PBK_SIZE].try_into()?; - let recv_pbk = PublicKey::from(slice); + async fn ecdh( + handler: &mut SocketHandler<'_>, + go_first: bool, + ) -> Result> { + let buf: Vec; + let own_sec = EphemeralSecret::new(OsRng); + let own_pbk = PublicKey::from(&own_sec); + let msg = own_pbk.as_bytes().to_vec(); - Ok(own_sec.diffie_hellman(&recv_pbk)) -} + if go_first { + handler.send(&msg).await?; + buf = handler.recv().await?; + } else { + buf = handler.recv().await?; + handler.send(&msg).await?; + } -pub async fn aes_cipher( - reader: &mut BufReader>, - writer: &mut BufWriter>, - go_first: bool, -) -> Result, Box> { - let secret = edh(reader, writer, go_first).await?; - Ok(Aes256Gcm::new(secret.as_bytes().into())) -} + let slice: [u8; DH_PBK_SIZE] = buf[..DH_PBK_SIZE].try_into()?; + let recv_pbk = PublicKey::from(slice); -fn generate_nonce(rng: &mut impl RngCore) -> Nonce { - let mut nonce = Nonce::default(); - rng.fill_bytes(&mut nonce); + Ok(own_sec.diffie_hellman(&recv_pbk)) + } - nonce -} + fn nonce(&self) -> Nonce { + let mut nonce = Nonce::default(); + self.rng.fill_bytes(&mut nonce); -pub fn aes_encrypt( - data: &Vec, - cipher: &mut AesGcm, - rng: &mut OsRng, -) -> Result, Box> { - let nonce = generate_nonce(rng); - let encrypted = match cipher.encrypt(&nonce, data.as_ref()) { - Ok(data) => data, - Err(_) => return Err("AES encryption failed".into()), - }; - let mut data = nonce.to_vec(); - data.extend_from_slice(&encrypted); + nonce + } - Ok(data) -} + pub async fn encrypt(&self, data: &[u8]) -> Result, Box> { + let nonce = self.nonce(); + let encrypted = match self.cipher.encrypt(&nonce, data.as_ref()) { + Ok(data) => data, + Err(e) => return Err(format!("Encryption failed: {}", e).into()), + }; -pub fn aes_decrypt( - data: &[u8], - cipher: &mut AesGcm, -) -> Result, Box> { - let (nonce_bytes, data) = data.split_at(AES_NONCE_SIZE); - let decrypted = match cipher.decrypt(Nonce::from_slice(nonce_bytes), data.as_ref()) { - Ok(data) => data, - Err(_) => return Err("AES decryption failed".into()), - }; + let mut data = nonce.to_vec(); + data.extend_from_slice(&encrypted); - Ok(decrypted) + Ok(data) + } + + pub async fn decrypt(&self, data: &[u8]) -> Result, Box> { + let (nonce_bytes, data) = data.split_at(AES_NONCE_SIZE); + let nonce = Nonce::from_slice(nonce_bytes); + let decrypted = match self.cipher.decrypt(nonce, data.as_ref()) { + Ok(data) => data, + Err(e) => return Err(format!("Decryption failed: {}", e).into()), + }; + + Ok(decrypted) + } } pub fn try_hash(path: &Path) -> Result> { @@ -92,31 +92,4 @@ pub fn try_hash(path: &Path) -> Result> { Ok(hash) } -pub fn keygen() -> String { - rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(8) - .map(char::from) - .collect::() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn aes_implementations() { - use aes_gcm::aead; - - let mut gen_rng = aead::OsRng; - let key = Aes256Gcm::generate_key(&mut gen_rng); - let mut cipher = Aes256Gcm::new(&key); - - let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; - let mut aes_rng = OsRng; - let enc = aes_encrypt(&data, &mut cipher, &mut aes_rng).unwrap(); - let dec = aes_decrypt(&enc, &mut cipher).unwrap(); - - assert_eq!(data, dec); - } -} +// TODO: unit test if deemed necessary diff --git a/src/lib.rs b/src/lib.rs index ac229c8..9128f46 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,8 @@ pub mod cli; pub mod common; -pub mod comms; pub mod connector; pub mod crypto; pub mod listener; pub mod parsers; +pub mod sockets; pub mod util; diff --git a/src/sockets.rs b/src/sockets.rs new file mode 100644 index 0000000..0b4d3b5 --- /dev/null +++ b/src/sockets.rs @@ -0,0 +1,80 @@ +use std::error::Error; + +use base64::{engine::general_purpose, Engine}; +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::{ + tcp::{ReadHalf, WriteHalf}, + TcpStream, + }, +}; + +use crate::crypto::Crypto; + +pub struct SocketHandler<'a> { + writer: BufWriter>, + reader: BufReader>, + crypto: Option, +} + +impl<'a> SocketHandler<'a> { + pub fn new(socket: &'a mut TcpStream) -> Self { + let (reader, writer) = socket.split(); + let mut reader = BufReader::new(reader); + let mut writer = BufWriter::new(writer); + + Self { + writer, + reader, + crypto: None, + } + } + + pub fn set_crypto(&self, crypto: Crypto) { + // setting up AES cipher requires DH key exchange in plaintext, + // meaning crypto can't be initialized at the same time as the socket handler + self.crypto = Some(crypto); + } + + pub async fn sender(&mut self, data: &[u8]) -> Result<(), Box> { + let data = match &self.crypto { + Some(c) => c.encrypt(data).await?, + None => data.to_vec(), + }; + + self.send(&data).await?; + + Ok(()) + } + + pub async fn send(&mut self, data: &[u8]) -> Result<(), Box> { + self.writer.write_all(data).await?; + self.writer.flush().await?; + + Ok(()) + } + + pub async fn receiver(&mut self) -> Result, Box> { + let mut buf = self.recv().await?; + buf.pop(); + buf = general_purpose::STANDARD_NO_PAD.decode(&buf)?.to_vec(); + + let data = match &self.crypto { + Some(c) => c.decrypt(&buf).await?, + None => buf, + }; + + Ok(data) + } + + pub async fn recv(&mut self) -> Result, Box> { + let mut buf = Vec::new(); + let n = self.reader.read_until(b':', &mut buf).await?; + + if n == 0 { + return Err("Received 0 bytes from the socket".into()); + } + + Ok(buf) + } +}