diff --git a/server/src/client.rs b/server/src/client.rs index 0318bcb..278feaf 100644 --- a/server/src/client.rs +++ b/server/src/client.rs @@ -1,18 +1,20 @@ -use std::sync::Arc; +use tokio::net::TcpStream; use std::cmp::Ordering; +use std::sync::Arc; use uuid::Uuid; use futures::lock::Mutex; -use tokio::sync::mpsc::{Sender, Receiver, channel}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; -use crate::network::SocketHandler; -use crate::messages::ClientMessage; -use crate::messages::ServerMessage; -use crate::prelude::StreamMessageSender; use foundation::ClientDetails; +use foundation::network::SocketHandler; +use foundation::prelude::StreamMessageSender; use foundation::messages::client::{ClientStreamIn, ClientStreamOut}; +use crate::messages::ClientMessage; +use crate::messages::ServerMessage; + /// # Client /// This struct represents a connected user. /// @@ -34,7 +36,7 @@ pub struct Client { tx: Sender, rx: Mutex>, - socket_sender: Arc, + socket_sender: Arc>, } // client funciton implmentations @@ -43,7 +45,7 @@ impl Client { uuid: String, username: String, address: String, - socket_sender: Arc, + socket_sender: Arc>, server_channel: Sender, ) -> Arc { let (sender, receiver) = channel(1024); @@ -53,7 +55,7 @@ impl Client { uuid: Uuid::parse_str(&uuid).expect("invalid id"), username, address, - public_key: None + public_key: None, }, server_channel: Mutex::new(server_channel), @@ -61,12 +63,10 @@ impl Client { tx: sender, rx: Mutex::new(receiver), - }) } pub fn start(self: &Arc) { - let t1_client = self.clone(); let t2_client = self.clone(); @@ -76,7 +76,11 @@ impl Client { let client = t1_client; - client.socket_sender.send::(ClientStreamOut::Connected).await.expect("error"); + client + .socket_sender + .send::(ClientStreamOut::Connected) + .await + .expect("error"); loop { let command = client.socket_sender.recv::().await; @@ -87,23 +91,36 @@ impl Client { return; } Ok(ClientStreamIn::SendMessage { to, content }) => { - println!("[Client {:?}]: send message to: {:?}", &client.details.uuid, &to); + println!( + "[Client {:?}]: send message to: {:?}", + &client.details.uuid, &to + ); let lock = client.server_channel.lock().await; - let _ = lock.send(ServerMessage::ClientSendMessage { - from: client.details.uuid, - to, - content, - }).await; + let _ = lock + .send(ServerMessage::ClientSendMessage { + from: client.details.uuid, + to, + content, + }) + .await; } Ok(ClientStreamIn::Update) => { println!("[Client {:?}]: update received", &client.details.uuid); let lock = client.server_channel.lock().await; - let _ = lock.send(ServerMessage::ClientUpdate { to: client.details.uuid }).await; + let _ = lock + .send(ServerMessage::ClientUpdate { + to: client.details.uuid, + }) + .await; } _ => { println!("[Client {:?}]: command not found", &client.details.uuid); let lock = client.server_channel.lock().await; - let _ = lock.send(ServerMessage::ClientError { to: client.details.uuid }).await; + let _ = lock + .send(ServerMessage::ClientError { + to: client.details.uuid, + }) + .await; } } } @@ -111,7 +128,7 @@ impl Client { // client channel read thread tokio::spawn(async move { - use ClientMessage::{Disconnect, Message, SendClients, Error}; + use ClientMessage::{Disconnect, Error, Message, SendClients}; let client = t2_client; @@ -125,32 +142,42 @@ impl Client { match message { Disconnect => { let lock = client.server_channel.lock().await; - let _ = lock.send(ServerMessage::ClientDisconnected { id: client.details.uuid }).await; - return + let _ = lock + .send(ServerMessage::ClientDisconnected { + id: client.details.uuid, + }) + .await; + return; } - Message { from, content } => - client.socket_sender.send::( - ClientStreamOut::UserMessage { from, content } - ).await.expect("error sending message"), - - SendClients { clients } => { - let client_details_vec: Vec = - clients.iter().map(|client| &client.details) - .cloned().collect(); + Message { from, content } => client + .socket_sender + .send::(ClientStreamOut::UserMessage { from, content }) + .await + .expect("error sending message"), - client.socket_sender.send::( - ClientStreamOut::ConnectedClients { + SendClients { clients } => { + let client_details_vec: Vec = clients + .iter() + .map(|client| &client.details) + .cloned() + .collect(); + + client + .socket_sender + .send::(ClientStreamOut::ConnectedClients { clients: client_details_vec, - } - ).await.expect("error sending message"); - }, - Error => - client.socket_sender.send::( - ClientStreamOut::Error - ).await.expect("error sending message"), + }) + .await + .expect("error sending message"); + } + Error => client + .socket_sender + .send::(ClientStreamOut::Error) + .await + .expect("error sending message"), } } - }); + }); } pub async fn send_message(self: &Arc, msg: ClientMessage) { diff --git a/server/src/client_manager.rs b/server/src/client_manager.rs index d750306..7e8d8c3 100644 --- a/server/src/client_manager.rs +++ b/server/src/client_manager.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; use std::sync::Arc; -use uuid::Uuid; -use tokio::sync::mpsc::{channel, Receiver, Sender}; use futures::lock::Mutex; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use uuid::Uuid; use crate::client::Client; use crate::messages::ClientMessage; @@ -37,19 +37,17 @@ impl ClientManager { } pub fn start(self: &Arc) { - let client_manager = self.clone(); tokio::spawn(async move { - - use ClientMgrMessage::{Add, Remove, SendClients, SendMessage, SendError}; + use ClientMgrMessage::{Add, Remove, SendClients, SendError, SendMessage}; loop { let mut receiver = client_manager.rx.lock().await; let message = receiver.recv().await.unwrap(); println!("[Client manager]: recieved message: {:?}", message); - + match message { Add(client) => { println!("[Client Manager]: adding new client"); @@ -66,25 +64,28 @@ impl ClientManager { } } SendMessage { to, from, content } => { - client_manager.send_to_client(&to, ClientMessage::Message { from, content }).await; + client_manager + .send_to_client(&to, ClientMessage::Message { from, content }) + .await; } SendClients { to } => { let lock = client_manager.clients.lock().await; if let Some(client) = lock.get(&to) { - let clients_vec: Vec> = - lock.values().cloned().collect(); + let clients_vec: Vec> = lock.values().cloned().collect(); - client.send_message(ClientMessage::SendClients { - clients: clients_vec, - }).await + client + .send_message(ClientMessage::SendClients { + clients: clients_vec, + }) + .await } - }, + } SendError { to } => { let lock = client_manager.clients.lock().await; if let Some(client) = lock.get(&to) { client.send_message(ClientMessage::Error).await } - }, + } #[allow(unreachable_patterns)] _ => println!("[Client manager]: not implemented"), } @@ -99,10 +100,7 @@ impl ClientManager { } } - pub async fn send_message( - self: Arc, - message: ClientMgrMessage) - { + pub async fn send_message(self: Arc, message: ClientMgrMessage) { let _ = self.tx.send(message).await; } } diff --git a/server/src/main.rs b/server/src/main.rs index 221008b..b6b69b4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,11 +1,9 @@ pub mod client; pub mod client_manager; -pub mod messages; pub mod network_manager; -pub mod network; -pub mod server; -pub mod encryption; pub mod prelude; +pub mod server; +pub mod messages; use std::io; @@ -34,5 +32,5 @@ async fn main() -> io::Result<()> { let server = Server::new().unwrap(); server.start().await; - Ok(()) + Ok(()) } diff --git a/server/src/messages.rs b/server/src/messages.rs index f703171..b9306c2 100644 --- a/server/src/messages.rs +++ b/server/src/messages.rs @@ -28,7 +28,7 @@ pub enum ClientMgrMessage { }, SendError { to: Uuid, - } + }, } #[derive(Debug)] @@ -48,6 +48,6 @@ pub enum ServerMessage { to: Uuid, }, ClientError { - to: Uuid - } + to: Uuid, + }, } diff --git a/server/src/network/mod.rs b/server/src/network/mod.rs index c622d7d..b6e420f 100644 --- a/server/src/network/mod.rs +++ b/server/src/network/mod.rs @@ -1,23 +1,24 @@ -use std::sync::Arc; -use std::io::Write; -use std::io::Error; use std::fmt::Debug; +use std::io::Error; +use std::io::Write; +use std::sync::Arc; use async_trait::async_trait; -use serde::Serialize; use serde::de::DeserializeOwned; +use serde::Serialize; use tokio::io::split; -use tokio::sync::Mutex; -use tokio::io::ReadHalf; -use tokio::io::BufReader; -use tokio::io::WriteHalf; -use tokio::net::TcpStream; -use tokio::io::AsyncWriteExt; use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::ReadHalf; +use tokio::io::WriteHalf; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::TcpStream; +use tokio::sync::Mutex; use crate::prelude::StreamMessageSender; +use crate::prelude::TransformerFn; -type TransformerVec = Vec &[u8]>; +type TransformerVec = Vec; pub struct SocketHandler { stream_tx: Mutex>, @@ -41,117 +42,147 @@ impl SocketHandler { }) } - pub async fn push_layer( - self: &Arc, - send_func: fn(&[u8]) -> &[u8], - recv_func: fn(&[u8]) -> &[u8], - ) { + pub async fn push_layer(self: &Arc, send_func: TransformerFn, recv_func: TransformerFn) { let mut send_lock = self.send_transformer.lock().await; let mut recv_lock = self.recv_transformer.lock().await; send_lock.push(send_func); - recv_lock.reverse(); recv_lock.push(recv_func); - recv_lock.reverse(); } - pub async fn pop_layer(self: &Arc,) { + pub async fn pop_layer(self: &Arc) { let mut send_lock = self.send_transformer.lock().await; let mut recv_lock = self.recv_transformer.lock().await; - let _ = send_lock.pop(); - - recv_lock.reverse(); let _ = recv_lock.pop(); - recv_lock.reverse(); } } #[async_trait] impl StreamMessageSender for SocketHandler { - async fn send - (self: &Arc, message: TOutMessage) -> Result<(), Error> - { + async fn send( + self: &Arc, + message: TOutMessage, + ) -> Result<(), Error> { let mut out_buffer: Vec = Vec::new(); let message_string = serde_json::to_string(&message)?; writeln!(out_buffer, "{}", message_string)?; + + println!("[SocketHandler:send] message_before: {:?}", &out_buffer); + + let transformers = self.send_transformer.lock().await; + let iter = transformers.iter(); + + for func in iter { + out_buffer = (**func)(&out_buffer); + } + + println!("[SocketHandler:send] message_after: {:?}", &out_buffer); + let mut lock = self.stream_tx.lock().await; lock.write_all(&out_buffer).await?; lock.flush().await?; Ok(()) } - async fn recv<'de, TInMessage: DeserializeOwned + Send> - (self: &Arc) -> Result - { + async fn recv<'de, TInMessage: DeserializeOwned + Send>( + self: &Arc, + ) -> Result { let mut in_buffer = String::new(); let mut lock = self.stream_rx.lock().await; lock.read_line(&mut in_buffer).await?; - let message: TInMessage = serde_json::from_str(&in_buffer) - .expect("[StreamMessageSender:recv] deserialisation failed"); + + println!("[SocketHandler:recv] message_before: {:?}", &in_buffer); + + let transformers = self.recv_transformer.lock().await; + let iter = transformers.iter(); + + let mut in_buffer = in_buffer.into_bytes(); + + for func in iter { + in_buffer = (**func)(&in_buffer); + } + + println!("[SocketHandler:recv] message_after: {:?}", &in_buffer); + + let in_buffer = String::from_utf8(in_buffer).expect("invalid utf_8"); + + let message: TInMessage = serde_json::from_str(&in_buffer).unwrap(); Ok(message) } } impl Debug for SocketHandler { - - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) - -> std::result::Result<(), std::fmt::Error> { - write!(f, "[SocketSender]") + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "[SocketSender]") } } #[cfg(test)] mod test { + use tokio::runtime::Runtime; + use std::sync::Once; use std::time::Duration; - use tokio::net::TcpStream; - use tokio::net::TcpListener; - use tokio::time::sleep; + use tokio::task; use tokio::io::{AsyncReadExt, AsyncWriteExt}; - + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::time::sleep; + use super::SocketHandler; + use crate::helpers::start_server; + use crate::helpers::create_test_shared; use crate::prelude::StreamMessageSender; + use crate::encryption::create_encryption_transformers; - async fn start_server() { - let listener = TcpListener::bind("127.0.0.1:5600").await.expect("failed to create listener"); - let mut buf = [0; 1024]; - loop { - let (mut socket, _) = listener.accept().await.expect("failed to accept connection"); + static SERVER_INIT: Once = Once::new(); - tokio::spawn(async move { - let n = match socket.read(&mut buf).await { - // socket closed - Ok(n) if n == 0 => return, - Ok(n) => n, - Err(e) => { - println!("failed to read from socket; err = {:?}", e); - return; - } - }; - - // Write the data back - if let Err(e) = socket.write_all(&buf[0..n]).await { - println!("failed to write to socket; err = {:?}", e); - return; - } + fn setup() { + SERVER_INIT.call_once(|| { + std::thread::spawn(|| { + let rt = Runtime::new().unwrap(); + rt.block_on(start_server()) + }); - } + }) } #[tokio::test] - async fn test_socket_sender() { - tokio::spawn(start_server()); + async fn test_socket_sender() { + setup(); + task::spawn(start_server()); - let socket = TcpStream::connect("localhost:5600").await.expect("failed to connect"); + let socket = TcpStream::connect("localhost:5600") + .await + .expect("failed to connect"); - sleep(Duration::from_secs(1)).await; - let handle = SocketHandler::new(socket); let _ = handle.send::(true).await; let message = handle.recv::().await.unwrap(); assert!(message); } -} \ No newline at end of file + + #[tokio::test] + async fn test_socket_sender_with_encryption() { + setup(); + task::spawn(start_server()); + + let socket = TcpStream::connect("localhost:5600") + .await + .unwrap(); + + let shared = create_test_shared(); + let (en, de) = create_encryption_transformers(shared, b"12345678901234567890123456789011"); + let handle = SocketHandler::new(socket); + + handle.push_layer(en, de).await; + + let _ = handle.send::(true).await; + let message = handle.recv::().await.unwrap(); + + assert!(message); + } +} diff --git a/server/src/network_manager.rs b/server/src/network_manager.rs index 7d5e478..7329656 100644 --- a/server/src/network_manager.rs +++ b/server/src/network_manager.rs @@ -3,11 +3,12 @@ use std::sync::Arc; use tokio::net::TcpListener; use tokio::sync::mpsc::Sender; -use crate::client::Client; -use crate::network::SocketHandler; -use crate::messages::ServerMessage; -use crate::prelude::StreamMessageSender; +use foundation::prelude::StreamMessageSender; use foundation::messages::network::{NetworkSockIn, NetworkSockOut}; +use foundation::network::SocketHandler; + +use crate::client::Client; +use crate::messages::ServerMessage; pub struct NetworkManager { address: String, @@ -23,11 +24,12 @@ impl NetworkManager { } pub fn start(self: &Arc) { - let network_manager = self.clone(); tokio::spawn(async move { - let listener = TcpListener::bind(network_manager.address.clone()).await.unwrap(); + let listener = TcpListener::bind(network_manager.address.clone()) + .await + .unwrap(); loop { let (connection, _) = listener.accept().await.unwrap(); @@ -35,22 +37,21 @@ impl NetworkManager { let server_channel = network_manager.server_channel.clone(); tokio::spawn(async move { + stream_sender + .send::(NetworkSockOut::Request) + .await + .expect("failed to send message"); - stream_sender.send::(NetworkSockOut::Request) - .await.expect("failed to send message"); - - if let Ok(request) = - stream_sender.recv::().await - { - + if let Ok(request) = stream_sender.recv::().await { match request { NetworkSockIn::Info => { - stream_sender.send( - NetworkSockOut::GotInfo { + stream_sender + .send(NetworkSockOut::GotInfo { server_name: "oof", server_owner: "michael", - } - ).await.expect("failed to send got info"); + }) + .await + .expect("failed to send got info"); } NetworkSockIn::Connect { uuid, @@ -66,14 +67,13 @@ impl NetworkManager { server_channel.clone(), ); let _ = server_channel - .send(ServerMessage::ClientConnected { - client: new_client, - }).await; + .send(ServerMessage::ClientConnected { client: new_client }) + .await; } } } }); } - }); + }); } } diff --git a/server/src/prelude.rs b/server/src/prelude.rs index e9b8154..1242c64 100644 --- a/server/src/prelude.rs +++ b/server/src/prelude.rs @@ -2,12 +2,18 @@ use std::sync::Arc; use async_trait::async_trait; -use serde::Serialize; use serde::de::DeserializeOwned; - +use serde::Serialize; #[async_trait] pub trait StreamMessageSender { - async fn send(self: &Arc, message: TOutMessage) -> Result<(), std::io::Error>; - async fn recv<'de, TInMessage: DeserializeOwned + Send>(self: &Arc) -> Result; + async fn send( + self: &Arc, + message: TOutMessage, + ) -> Result<(), std::io::Error>; + async fn recv<'de, TInMessage: DeserializeOwned + Send>( + self: &Arc, + ) -> Result; } + +pub type TransformerFn = Box Vec + Send + Sync>; diff --git a/server/src/server.rs b/server/src/server.rs index 9d72786..cb26327 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,9 +1,9 @@ use std::sync::Arc; // use crossbeam_channel::{unbounded, Receiver}; -use uuid::Uuid; -use tokio::sync::mpsc::{channel, Receiver}; use futures::lock::Mutex; +use tokio::sync::mpsc::{channel, Receiver}; +use uuid::Uuid; use crate::client_manager::ClientManager; use crate::messages::ClientMgrMessage; @@ -22,7 +22,7 @@ pub enum ServerMessages { /// authors: @michael-bailey, @Mitch161 /// This Represents a server instance. /// it is componsed of a client manager and a network manager -/// +/// pub struct Server { client_manager: Arc, network_manager: Arc, @@ -34,19 +34,14 @@ impl Server { pub fn new() -> Result, Box> { let (sender, receiver) = channel(1024); - Ok( - Arc::new( - Server { - client_manager: ClientManager::new(sender.clone()), - network_manager: NetworkManager::new("5600".to_string(), sender), - receiver: Mutex::new(receiver), - } - ) - ) + Ok(Arc::new(Server { + client_manager: ClientManager::new(sender.clone()), + network_manager: NetworkManager::new("5600".to_string(), sender), + receiver: Mutex::new(receiver), + })) } pub async fn start(self: &Arc) { - // start client manager and network manager self.network_manager.clone().start(); self.client_manager.clone().start(); @@ -54,7 +49,6 @@ impl Server { // clone block items let server = self.clone(); - use ClientMgrMessage::{Add, Remove, SendMessage}; loop { @@ -64,25 +58,39 @@ impl Server { match message { ServerMessage::ClientConnected { client } => { - server.client_manager.clone() - .send_message(Add(client)).await + server + .client_manager + .clone() + .send_message(Add(client)) + .await } ServerMessage::ClientDisconnected { id } => { println!("disconnecting client {:?}", id); server.client_manager.clone().send_message(Remove(id)).await; } - ServerMessage::ClientSendMessage { from, to, content } => server - .client_manager.clone() - .send_message(SendMessage { from, to, content }).await, - ServerMessage::ClientUpdate { to } => server - .client_manager.clone() - .send_message(ClientMgrMessage::SendClients { to }).await, - ServerMessage::ClientError { to } => server - .client_manager.clone() - .send_message(ClientMgrMessage::SendError {to}).await, + ServerMessage::ClientSendMessage { from, to, content } => { + server + .client_manager + .clone() + .send_message(SendMessage { from, to, content }) + .await + } + ServerMessage::ClientUpdate { to } => { + server + .client_manager + .clone() + .send_message(ClientMgrMessage::SendClients { to }) + .await + } + ServerMessage::ClientError { to } => { + server + .client_manager + .clone() + .send_message(ClientMgrMessage::SendError { to }) + .await + } } } } } } -