diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index edf72fb..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Rust - -on: - push: - branches: [ ref-method ] - pull_request: - branches: [ master ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - - steps: - - uses: actions/checkout@v2 - - name: check - run: cargo check --verbose - - name: Build - run: cargo build --verbose - diff --git a/Cargo.toml b/Cargo.toml index 00e10ba..be07e27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,4 @@ members = [ 'server', 'client', 'serverctl' -] \ No newline at end of file +] diff --git a/README.md b/README.md index 260be15..de85e4c 100644 --- a/README.md +++ b/README.md @@ -1,32 +1 @@ -# Rust-chat-server - -A Chat server writen in rust to allow communication between peers. - ---- - -## Features: -- implemented: - - json based API. - - Server introspection. - - Peer discovery. - - sending messages to connected clients. - - -- todo: - - Encryption to server. - - server to server meshing. - - asynchronous client managment instead of threaded approach. - -## Goals: -- Learn the rust programming lanaguage. - - Ownership: how that affects normal programming styles. - - Borrowing and references: how this affects shared state. - - Lifetimes: how this affects data retention and sharing. -- Learn how to create networked programs. - - Application level protocol: how to get two programs to communicate via TCP sockets. - - Socket handling: Discovering ways to handle multiple socket connections without affecting performance. -- Learn common encryption protocols. - - Adding support for encrypted sockets. - - Pros and cons of symetric and asymetric encryption. - - resolving common encryption flaws - -> Questions: For questions please add a issue with the question label. It will eventually be responded to +# rust-chat-server \ No newline at end of file diff --git a/foundation/Cargo.toml b/foundation/Cargo.toml index a20e1ab..a0c361d 100644 --- a/foundation/Cargo.toml +++ b/foundation/Cargo.toml @@ -9,17 +9,14 @@ edition = "2018" [dependencies] regex = "1" -crossbeam = "0.8.0" -crossbeam-channel = "0.5.0" -crossbeam-queue = "0.3.1" -parking_lot = "0.11.1" -dashmap = "4.0.2" -rayon = "1.3.1" -zeroize = "1.1.0" -crossterm = "0.19.0" -log = "0.4" url = "2.2.0" -uuid = {version = "0.8", features = ["serde", "v4"]} -serde = { version = "1.0", features = ["derive"] } +openssl = "0.10" +base64 = "0.13.0" +zeroize = "1.1.0" serde_json = "1.0" +futures = "0.3.16" +async-trait = "0.1.51" +serde = { version = "1.0", features = ["derive"] } +tokio = { version = "1.9.0", features = ["full"] } +uuid = {version = "0.8", features = ["serde", "v4"]} \ No newline at end of file diff --git a/foundation/src/encryption/helpers.rs b/foundation/src/encryption/helpers.rs new file mode 100644 index 0000000..4209cae --- /dev/null +++ b/foundation/src/encryption/helpers.rs @@ -0,0 +1,30 @@ +use openssl::derive::Deriver; +use openssl::ec::EcGroup; +use openssl::ec::EcKey; +use openssl::nid::Nid; +use openssl::pkey::PKey; + +pub fn create_test_shared() -> Vec { + let ec_group1 = EcGroup::from_curve_name(Nid::SECP256K1).unwrap(); + let ec_group2 = EcGroup::from_curve_name(Nid::SECP256K1).unwrap(); + + let eckey1 = EcKey::generate(ec_group1.as_ref()).unwrap(); + let eckey2 = EcKey::generate(ec_group2.as_ref()).unwrap(); + + let pkey1 = PKey::from_ec_key(eckey1).unwrap(); + let pkey2 = PKey::from_ec_key(eckey2).unwrap(); + + let pem1 = pkey1.public_key_to_pem().unwrap(); + let pem2 = pkey2.public_key_to_pem().unwrap(); + + let pub1 = PKey::public_key_from_pem(&pem1).unwrap(); + let pub2 = PKey::public_key_from_pem(&pem2).unwrap(); + + let mut deriver1 = Deriver::new(pkey1.as_ref()).expect("deriver1 failed"); + let mut deriver2 = Deriver::new(pkey2.as_ref()).expect("deriver2 failed"); + + deriver1.set_peer(pub2.as_ref()).unwrap(); + deriver2.set_peer(pub1.as_ref()).unwrap(); + + deriver1.derive_to_vec().unwrap() +} \ No newline at end of file diff --git a/foundation/src/encryption/mod.rs b/foundation/src/encryption/mod.rs new file mode 100644 index 0000000..0196d52 --- /dev/null +++ b/foundation/src/encryption/mod.rs @@ -0,0 +1,97 @@ +pub mod helpers; + +use crate::prelude::TransformerFn; +use openssl::symm::{Cipher, Crypter, Mode}; + +#[allow(clippy::clone_on_copy)] +pub fn create_encryption_transformers( + key: Vec, + iv: &[u8; 32], +) -> (TransformerFn, TransformerFn) { + // clone vecs + let key1 = key.clone(); + let key2 = key.clone(); + + let iv1 = iv.clone(); + let iv2 = iv.clone(); + + ( + Box::new(move |plain_text| { + println!("[encryptor_fn] plain_text:{:?}", plain_text); + let encrypter = Crypter::new(Cipher::aes_256_gcm(), Mode::Encrypt, &key1, Some(&iv1)); + let mut ciphertext = vec![0u8; 128]; + let _cipherlen = encrypter + .unwrap() + .update(plain_text, &mut ciphertext) + .unwrap(); + ciphertext + }), + Box::new(move |cipher_text| { + println!("[decryptor_fn] cipher_text:{:?}", cipher_text); + let decrypter = Crypter::new(Cipher::aes_256_gcm(), Mode::Decrypt, &key2, Some(&iv2)); + let mut plain_text = vec![0u8; 128]; + decrypter + .unwrap() + .update(cipher_text, &mut plain_text) + .unwrap(); + plain_text + }), + ) +} + +#[cfg(test)] +mod test { + use openssl::sha::sha256; + use openssl::symm::{Cipher, Crypter, Mode}; + + use super::create_encryption_transformers; + use super::helpers::create_test_shared; + + #[test] + pub fn test_transformer_functions() { + let shared = create_test_shared(); + + let (en, de) = create_encryption_transformers(shared, b"12345678901234561234561234567765"); + + let message = b"Hello world"; + + let cipher_text = (*en)(message); + + assert_ne!(&cipher_text[0..message.len()], message); + + let decrypted_text = (*de)(&cipher_text); + + assert_eq!(&decrypted_text[0..message.len()], message); + } + + #[test] + pub fn test_aes() { + let shared = create_test_shared(); + + let plaintext = b"This is a message"; + let key = sha256(&shared); + let iv = b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + + let encrypter = Crypter::new(Cipher::aes_256_gcm(), Mode::Encrypt, &key, Some(iv)); + let mut ciphertext = vec![0u8; 1024]; + let cipherlen = encrypter + .unwrap() + .update(plaintext, &mut ciphertext) + .unwrap(); + + let decrypter = Crypter::new(Cipher::aes_256_gcm(), Mode::Decrypt, &key, Some(iv)); + let mut decrypted = vec![0u8; 1024]; + decrypter + .unwrap() + .update(&ciphertext[..cipherlen], &mut decrypted) + .unwrap(); + + println!("plaintext: {:?}", plaintext); + println!("ciphertext: {:?}", &ciphertext[0..plaintext.len()]); + println!("decryptedtext: {:?}", &decrypted[0..plaintext.len()]); + + let test: &[u8] = &decrypted; + + assert_eq!(&test[0..plaintext.len()], plaintext); + } +} diff --git a/foundation/src/helpers.rs b/foundation/src/helpers.rs new file mode 100644 index 0000000..93e838e --- /dev/null +++ b/foundation/src/helpers.rs @@ -0,0 +1,115 @@ +use std::pin::Pin; +use std::io::Error; +use std::task::Poll; +use std::sync::Mutex; +use std::task::Context; + +use tokio::io::ReadBuf; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub struct BufferStream { + buffer: Mutex>, +} + +impl BufferStream { + pub fn new() -> BufferStream { + BufferStream { + buffer: Mutex::new(Vec::new()), + } + } +} + +impl AsyncRead for BufferStream { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context, + buf: &mut ReadBuf<'_> + ) -> Poll> { + let mut lock = self.buffer.lock().unwrap(); + + let a = if buf.remaining() < lock.len() {buf.remaining()} else {lock.len()}; + + buf.put_slice(&lock[..a]); + + *lock = Vec::from(&lock[a..]); + + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for BufferStream { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8] + ) -> Poll> { + let mut lock = self.buffer.lock().unwrap(); + lock.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod test { + + use tokio::io::split; + use tokio::io::AsyncWriteExt; + + use crate::helpers::BufferStream; + use tokio::io::AsyncReadExt; + + #[tokio::test] + async fn test_reading_and_writing() { + let stream = BufferStream::new(); + + let (mut rd, mut wd) = split(stream); + + let _ = wd.write_all(b"1010").await; + + let mut buf: [u8; 4] = [0; 4]; + + let _ = rd.read(&mut buf[..]).await; + + println!("[test_reading_and_writing] {:?}", &buf[..]); + + assert_eq!(b"1010", &buf[..]); + } + + + #[tokio::test] + async fn test_reading_small() { + let stream = BufferStream::new(); + + let (mut rd, mut wd) = split(stream); + + let _ = wd.write_all(b"10100101").await; + + let mut buf: [u8; 4] = [0; 4]; + + let _ = rd.read(&mut buf[..]).await; + + println!("[test_reading_and_writing] {:?}", &buf[..]); + + assert_eq!(b"1010", &buf[..]); + + let _ = rd.read(&mut buf[..]).await; + + println!("[test_reading_and_writing] {:?}", &buf[..]); + + assert_eq!(b"0101", &buf[..]); + } +} \ No newline at end of file diff --git a/foundation/src/lib.rs b/foundation/src/lib.rs index 3ff3748..1458ccf 100644 --- a/foundation/src/lib.rs +++ b/foundation/src/lib.rs @@ -1,12 +1,24 @@ +pub mod encryption; pub mod messages; pub mod prelude; +pub mod network; +pub mod helpers; use serde::{Deserialize, Serialize}; use uuid::Uuid; +/** + * #ClientDetails. + * This defines the fileds a client would want to send when connecitng + * uuid: the unique id of the user. + * username: the users user name. + * address: the ip address of the connected user. + * public_key: the public key used when sending messages to the user. + */ #[derive(Deserialize, Serialize, Debug, Clone)] pub struct ClientDetails { - pub uuid: Uuid, - pub username: String, - pub address: String, -} \ No newline at end of file + pub uuid: Uuid, + pub username: String, + pub address: String, + pub public_key: Option>, +} diff --git a/foundation/src/messages/client.rs b/foundation/src/messages/client.rs index cabe3bc..dfb603b 100644 --- a/foundation/src/messages/client.rs +++ b/foundation/src/messages/client.rs @@ -7,7 +7,8 @@ use uuid::Uuid; /// This enum defined the message that a client can receive from the server /// This uses the serde library to transform to and from json. /// -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] +#[serde(tag = "type")] pub enum ClientStreamIn { Connected, @@ -18,14 +19,17 @@ pub enum ClientStreamIn { Disconnect, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] +#[serde(tag = "type")] pub enum ClientStreamOut { Connected, UserMessage { from: Uuid, content: String }, GlobalMessage { content: String }, - ConnectedClients {clients: Vec}, + ConnectedClients { clients: Vec }, Disconnected, + + Error, } diff --git a/foundation/src/messages/network.rs b/foundation/src/messages/network.rs index 98a2683..6a14abc 100644 --- a/foundation/src/messages/network.rs +++ b/foundation/src/messages/network.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] +#[serde(tag = "type")] pub enum NetworkSockIn { Info, Connect { @@ -10,7 +11,8 @@ pub enum NetworkSockIn { }, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] +#[serde(tag = "type")] pub enum NetworkSockOut<'a> { Request, diff --git a/foundation/src/network/mod.rs b/foundation/src/network/mod.rs new file mode 100644 index 0000000..5c9bf4b --- /dev/null +++ b/foundation/src/network/mod.rs @@ -0,0 +1,184 @@ +use tokio::io::AsyncWrite; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use std::fmt::Debug; +use std::io::Error; +use std::io::Write; +use std::sync::Arc; + +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use serde::Serialize; +use tokio::io::split; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::ReadHalf; +use tokio::io::WriteHalf; +use tokio::io::{BufReader}; +use tokio::sync::Mutex; + +use crate::prelude::StreamMessageSender; +use crate::prelude::TransformerFn; + +type TransformerVec = Vec; + +pub struct SocketHandler +where + T: AsyncRead + AsyncWrite + Send +{ + stream_tx: Mutex>, + stream_rx: Mutex>>, + + send_transformer: Mutex, + recv_transformer: Mutex, +} + +impl SocketHandler +where + T: AsyncReadExt + AsyncWriteExt + Send +{ + pub fn new(connection: T) -> Arc { + let (rd, wd) = split(connection); + let reader = BufReader::new(rd); + + Arc::new(SocketHandler { + stream_tx: Mutex::new(wd), + stream_rx: Mutex::new(reader), + + send_transformer: Mutex::new(Vec::new()), + recv_transformer: Mutex::new(Vec::new()), + }) + } + + pub async fn push_layer(self: &Arc, send_func: TransformerFn, recv_func: TransformerFn) { + let mut send_lock = self.send_transformer.lock().await; + let mut recv_lock = self.recv_transformer.lock().await; + send_lock.push(send_func); + recv_lock.push(recv_func); + } + + pub async fn pop_layer(self: &Arc) { + let mut send_lock = self.send_transformer.lock().await; + let mut recv_lock = self.recv_transformer.lock().await; + let _ = send_lock.pop(); + let _ = recv_lock.pop(); + } +} + +#[async_trait] +impl StreamMessageSender for SocketHandler +where + T: AsyncReadExt + AsyncWriteExt + Send +{ + async fn send( + self: &Arc, + message: TOutMessage, + ) -> Result<(), Error> { + let message_string = serde_json::to_string(&message)?; + let mut out_buffer = Vec::from(message_string); + let message_length = out_buffer.len(); + println!("[SocketHandler:send] message_length:{:?}", &message_length); + + println!("[SocketHandler:send] message_before: {:?}", &out_buffer); + + let transformers = self.send_transformer.lock().await; + let iter = transformers.iter(); + + for func in iter { + let transform = (**func)(&out_buffer); + out_buffer.clear(); + out_buffer.extend_from_slice(&transform); + } + + let data = base64::encode(&out_buffer[..message_length]); + + println!("[SocketHandler:send] message_encode_base64: {:?}", &data); + + out_buffer.clear(); + + writeln!(out_buffer, "{}", data)?; + + println!("[SocketHandler:send] message_out: {:?}", &out_buffer); + + let mut lock = self.stream_tx.lock().await; + lock.write_all(&out_buffer).await?; + lock.flush().await?; + Ok(()) + } + + async fn recv<'de, TInMessage: DeserializeOwned + Send>( + self: &Arc, + ) -> Result { + let mut in_buffer = String::new(); + let mut lock = self.stream_rx.lock().await; + let mut length = lock.read_line(&mut in_buffer).await.unwrap(); + in_buffer.pop(); + println!("[SocketHandler:recv] message_in: {:?}", &in_buffer); + + let mut in_buffer = base64::decode(in_buffer).unwrap(); + println!("[SocketHandler:recv] message_decoded_base64: {:?}", &in_buffer); + + length = in_buffer.len(); + + let transformers = self.recv_transformer.lock().await; + let iter = transformers.iter().rev(); + for func in iter { + let transform = (**func)(&in_buffer); + in_buffer.clear(); + in_buffer.extend_from_slice(&transform[..length]); + } + println!("[SocketHandler:recv] message_after_transoformed: {:?}", &in_buffer); + + let in_buffer = String::from_utf8(in_buffer).unwrap(); + let message: TInMessage = serde_json::from_str(&in_buffer).unwrap(); + Ok(message) + } +} + +impl Debug for SocketHandler +where + T: AsyncReadExt + AsyncWriteExt + Send +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "[SocketSender]") + } +} + +#[cfg(test)] +mod test { + use crate::helpers::BufferStream; + use super::SocketHandler; + + use crate::prelude::StreamMessageSender; + use crate::encryption::helpers::create_test_shared; + use crate::encryption::create_encryption_transformers; + + #[tokio::test] + async fn test_socket_sender() { + + let stream = BufferStream::new(); + + let handle = SocketHandler::new(stream); + let _ = handle.send::(true).await.unwrap(); + let message = handle.recv::().await.unwrap(); + + assert!(message); + } + + #[tokio::test] + async fn test_socket_sender_with_encryption() { + + let stream = BufferStream::new(); + + let shared = create_test_shared(); + let (en, de) = create_encryption_transformers(shared, b"12345678901234567890123456789011"); + let handle = SocketHandler::new(stream); + + handle.push_layer(en, de).await; + + handle.send::(true).await.unwrap(); + let message = handle.recv::().await.unwrap(); + + assert!(message); + } +} diff --git a/foundation/src/prelude.rs b/foundation/src/prelude.rs index 92ff1d7..6709dbe 100644 --- a/foundation/src/prelude.rs +++ b/foundation/src/prelude.rs @@ -1,15 +1,19 @@ use std::sync::Arc; -pub trait IMessagable { - fn send_message(&self, msg: TMessage); - fn set_sender(&self, sender: TSender); -} +use async_trait::async_trait; -pub trait ICooperative { - fn tick(&self); -} +use serde::de::DeserializeOwned; +use serde::Serialize; -pub trait IPreemptive { - fn run(arc: &Arc); - fn start(arc: &Arc); -} +pub type TransformerFn = Box Vec + Send + Sync>; + +#[async_trait] +pub trait StreamMessageSender { + async fn send( + self: &Arc, + message: TOutMessage, + ) -> Result<(), std::io::Error>; + async fn recv<'de, TInMessage: DeserializeOwned + Send>( + self: &Arc, + ) -> Result; +} \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml index 779de58..544999c 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,2 @@ hard_tabs = true -max_width = 90 \ No newline at end of file +max_width = 100 \ No newline at end of file diff --git a/server/Cargo.toml b/server/Cargo.toml index a91dd21..8ac20d3 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -13,6 +13,10 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" crossbeam = "0.8.0" crossbeam-channel = "0.5.0" +zeroize = "1.1.0" +openssl = "0.10.33" +tokio = { version = "1.9.0", features = ["full"] } +futures = "0.3.16" +async-trait = "0.1.51" -[dependencies.foundation] -path = '../foundation' \ No newline at end of file +foundation = {path = '../foundation'} \ No newline at end of file diff --git a/server/src/client.rs b/server/src/client.rs index d5e1efd..278feaf 100644 --- a/server/src/client.rs +++ b/server/src/client.rs @@ -1,60 +1,42 @@ -use crate::messages::ClientMessage; -use crate::messages::ServerMessage; -use foundation::prelude::IPreemptive; +use tokio::net::TcpStream; use std::cmp::Ordering; -use std::io::BufRead; -use std::io::Write; -use std::io::{BufReader, BufWriter}; -use std::mem::replace; -use std::net::TcpStream; use std::sync::Arc; -use std::sync::Mutex; -use crossbeam_channel::{unbounded, Receiver, Sender}; -use serde::Serialize; use uuid::Uuid; +use futures::lock::Mutex; +use tokio::sync::mpsc::{channel, Receiver, Sender}; + use foundation::ClientDetails; +use foundation::network::SocketHandler; +use foundation::prelude::StreamMessageSender; use foundation::messages::client::{ClientStreamIn, ClientStreamOut}; -use foundation::prelude::IMessagable; + +use crate::messages::ClientMessage; +use crate::messages::ServerMessage; /// # Client /// This struct represents a connected user. /// -/// ## Attrubutes -/// - uuid: The id of the connected user. -/// - username: The username of the connected user. -/// - address: The the address of the connected client. +/// ## Attributes +/// - details: store of the clients infomation. /// /// - stream: The socket for the connected client. /// - stream_reader: the buffered reader used to receive messages /// - stream_writer: the buffered writer used to send messages /// - owner: An optional reference to the owning object. -#[derive(Debug, Serialize)] +#[derive(Debug)] pub struct Client { - pub uuid: Uuid, - username: String, - address: String, pub details: ClientDetails, - // non serializable - #[serde(skip)] - server_channel: Mutex>>, + // server send channel + server_channel: Mutex>, - #[serde(skip)] - input: Sender, + // object channels + tx: Sender, + rx: Mutex>, - #[serde(skip)] - output: Receiver, - - #[serde(skip)] - stream: Mutex>, - - #[serde(skip)] - stream_reader: Mutex>>, - - #[serde(skip)] - stream_writer: Mutex>>, + socket_sender: Arc>, } // client funciton implmentations @@ -63,199 +45,150 @@ impl Client { uuid: String, username: String, address: String, - stream: TcpStream, + socket_sender: Arc>, server_channel: Sender, ) -> Arc { - let (sender, receiver) = unbounded(); - - let out_stream = stream.try_clone().unwrap(); - let in_stream = stream.try_clone().unwrap(); + let (sender, receiver) = channel(1024); Arc::new(Client { - username: username.clone(), - uuid: Uuid::parse_str(&uuid).expect("invalid id"), - address: address.clone(), - details: ClientDetails { uuid: Uuid::parse_str(&uuid).expect("invalid id"), username, address, + public_key: None, }, - server_channel: Mutex::new(Some(server_channel)), + server_channel: Mutex::new(server_channel), + socket_sender, - input: sender, - output: receiver, - - stream: Mutex::new(Some(stream)), - - stream_reader: Mutex::new(Some(BufReader::new(in_stream))), - stream_writer: Mutex::new(Some(BufWriter::new(out_stream))), + tx: sender, + rx: Mutex::new(receiver), }) } -} -impl IMessagable> for Client { - fn send_message(&self, msg: ClientMessage) { - self.input - .send(msg) - .expect("failed to send message to client."); - } - fn set_sender(&self, sender: Sender) { - let mut server_lock = self.server_channel.lock().unwrap(); - let _ = replace(&mut *server_lock, Some(sender)); - } -} + pub fn start(self: &Arc) { + let t1_client = self.clone(); + let t2_client = self.clone(); -// cooperative multitasking implementation -impl IPreemptive for Client { - fn run(arc: &Arc) { - let arc1 = arc.clone(); - let arc2 = arc.clone(); + // client stream read task + tokio::spawn(async move { + use ClientMessage::Disconnect; - // read thread - let _ = std::thread::Builder::new() - .name(format!("client thread recv [{:?}]", &arc.uuid)) - .spawn(move || { - use ClientMessage::{Disconnect}; - let arc = arc1; + let client = t1_client; - let mut buffer = String::new(); - let mut reader_lock = arc.stream_reader.lock().unwrap(); - let reader = reader_lock.as_mut().unwrap(); + client + .socket_sender + .send::(ClientStreamOut::Connected) + .await + .expect("error"); - 'main: while let Ok(size) = reader.read_line(&mut buffer) { - if size == 0 { - arc.send_message(Disconnect); - break 'main; + loop { + let command = client.socket_sender.recv::().await; + match command { + Ok(ClientStreamIn::Disconnect) => { + println!("[Client {:?}]: Disconnect recieved", &client.details.uuid); + client.send_message(Disconnect).await; + return; } - - let command = serde_json::from_str::(buffer.as_str()); - match command { - Ok(ClientStreamIn::Disconnect) => { - println!("[Client {:?}]: Disconnect recieved", &arc.uuid); - arc.send_message(Disconnect); - break 'main; - } - Ok(ClientStreamIn::SendMessage { to, content }) => { - println!( - "[Client {:?}]: send message to: {:?}", - &arc.uuid, &to - ); - let lock = arc.server_channel.lock().unwrap(); - let sender = lock.as_ref().unwrap(); - let _ = sender.send(ServerMessage::ClientSendMessage { - from: arc.uuid, + Ok(ClientStreamIn::SendMessage { to, content }) => { + println!( + "[Client {:?}]: send message to: {:?}", + &client.details.uuid, &to + ); + let lock = client.server_channel.lock().await; + let _ = lock + .send(ServerMessage::ClientSendMessage { + from: client.details.uuid, to, content, - }); - } - _ => println!("[Client {:?}]: command not found", &arc.uuid), + }) + .await; + } + Ok(ClientStreamIn::Update) => { + println!("[Client {:?}]: update received", &client.details.uuid); + let lock = client.server_channel.lock().await; + let _ = lock + .send(ServerMessage::ClientUpdate { + to: client.details.uuid, + }) + .await; + } + _ => { + println!("[Client {:?}]: command not found", &client.details.uuid); + let lock = client.server_channel.lock().await; + let _ = lock + .send(ServerMessage::ClientError { + to: client.details.uuid, + }) + .await; } } - println!("[Client {:?}] exited thread 1", &arc.uuid); - }); + } + }); - // write thread - let _ = std::thread::Builder::new() - .name(format!("client thread msg [{:?}]", &arc.uuid)) - .spawn(move || { - let arc = arc2; - let mut writer_lock = arc.stream_writer.lock().unwrap(); - let writer = writer_lock.as_mut().unwrap(); - let mut buffer: Vec = Vec::new(); + // client channel read thread + tokio::spawn(async move { + use ClientMessage::{Disconnect, Error, Message, SendClients}; - let _ = writeln!( - buffer, - "{}", - serde_json::to_string(&ClientStreamOut::Connected).unwrap() - ); - let _ = writer.write_all(&buffer); - let _ = writer.flush(); + let client = t2_client; - 'main: loop { - for message in arc.output.iter() { - use ClientMessage::{Disconnect,Message, Update}; - println!("[Client {:?}]: {:?}", &arc.uuid, message); - match message { - Disconnect => { - arc.server_channel - .lock() - .unwrap() - .as_mut() - .unwrap() - .send(ServerMessage::ClientDisconnected(arc.uuid)) - .unwrap(); - break 'main; - } - Message { from, content } => { - let _ = writeln!( - buffer, - "{}", - serde_json::to_string( - &ClientStreamOut::UserMessage { from, content } - ) - .unwrap() - ); - let _ = writer.write_all(&buffer); - let _ = writer.flush(); - } - Update {clients} => { - let client_details_vec: Vec = clients.iter().map(|client| &client.details).cloned().collect(); - let _ = writeln!( - buffer, - "{}", - serde_json::to_string( - &ClientStreamOut::ConnectedClients {clients: client_details_vec} - ).unwrap() - ); - let _ = writer.write_all(&buffer); - let _ = writer.flush(); - } - } + loop { + let mut channel = client.rx.lock().await; + + let message = channel.recv().await.unwrap(); + drop(channel); + + println!("[Client {:?}]: {:?}", &client.details.uuid, message); + match message { + Disconnect => { + let lock = client.server_channel.lock().await; + let _ = lock + .send(ServerMessage::ClientDisconnected { + id: client.details.uuid, + }) + .await; + return; } + Message { from, content } => client + .socket_sender + .send::(ClientStreamOut::UserMessage { from, content }) + .await + .expect("error sending message"), + + SendClients { clients } => { + let client_details_vec: Vec = clients + .iter() + .map(|client| &client.details) + .cloned() + .collect(); + + client + .socket_sender + .send::(ClientStreamOut::ConnectedClients { + clients: client_details_vec, + }) + .await + .expect("error sending message"); + } + Error => client + .socket_sender + .send::(ClientStreamOut::Error) + .await + .expect("error sending message"), } - println!("[Client {:?}]: exited thread 2", &arc.uuid); - }); + } + }); } - fn start(arc: &Arc) { - Client::run(arc) - } -} - -// default value implementation -impl Default for Client { - fn default() -> Self { - let (sender, reciever) = unbounded(); - Client { - username: "generic_client".to_string(), - uuid: Uuid::new_v4(), - address: "127.0.0.1".to_string(), - - details: ClientDetails { - uuid: Uuid::new_v4(), - username: "generic_client".to_string(), - address: "127.0.0.1".to_string(), - }, - - output: reciever, - input: sender, - - server_channel: Mutex::new(None), - - stream: Mutex::new(None), - - stream_reader: Mutex::new(None), - stream_writer: Mutex::new(None), - } + pub async fn send_message(self: &Arc, msg: ClientMessage) { + let _ = self.tx.send(msg).await; } } // MARK: - used for sorting. impl PartialEq for Client { fn eq(&self, other: &Self) -> bool { - self.uuid == other.uuid + self.details.uuid == other.details.uuid } } @@ -269,7 +202,7 @@ impl PartialOrd for Client { impl Ord for Client { fn cmp(&self, other: &Self) -> Ordering { - self.uuid.cmp(&other.uuid) + self.details.uuid.cmp(&other.details.uuid) } } diff --git a/server/src/client_manager.rs b/server/src/client_manager.rs index a3f4d51..7e8d8c3 100644 --- a/server/src/client_manager.rs +++ b/server/src/client_manager.rs @@ -1,18 +1,14 @@ -// use crate::lib::server::ServerMessages; -use foundation::prelude::IPreemptive; use std::collections::HashMap; -use std::mem::replace; use std::sync::Arc; -use std::sync::Mutex; -use crossbeam_channel::{unbounded, Receiver, Sender}; +use futures::lock::Mutex; +use tokio::sync::mpsc::{channel, Receiver, Sender}; use uuid::Uuid; use crate::client::Client; use crate::messages::ClientMessage; use crate::messages::ClientMgrMessage; use crate::messages::ServerMessage; -use foundation::prelude::IMessagable; /// # ClientManager /// This struct manages all connected users @@ -22,93 +18,89 @@ pub struct ClientManager { server_channel: Mutex>, - sender: Sender, - receiver: Receiver, + tx: Sender, + rx: Mutex>, } impl ClientManager { pub fn new(server_channel: Sender) -> Arc { - let (sender, receiver) = unbounded(); + let (tx, rx) = channel(1024); Arc::new(ClientManager { clients: Mutex::default(), server_channel: Mutex::new(server_channel), - sender, - receiver, + tx, + rx: Mutex::new(rx), }) } -} -impl IMessagable> for ClientManager { - fn send_message(&self, msg: ClientMgrMessage) { - self.sender.send(msg).unwrap(); - } - fn set_sender(&self, sender: Sender) { - let mut server_lock = self.server_channel.lock().unwrap(); - let _ = replace(&mut *server_lock, sender); - } -} + pub fn start(self: &Arc) { + let client_manager = self.clone(); -impl IPreemptive for ClientManager { - fn run(arc: &Arc) { - loop { - std::thread::sleep(std::time::Duration::from_secs(1)); + tokio::spawn(async move { + use ClientMgrMessage::{Add, Remove, SendClients, SendError, SendMessage}; - if !arc.receiver.is_empty() { - for message in arc.receiver.try_iter() { - println!("[Client manager]: recieved message: {:?}", message); - use ClientMgrMessage::{Add, Remove, SendMessage, SendClients}; + loop { + let mut receiver = client_manager.rx.lock().await; + let message = receiver.recv().await.unwrap(); - match message { - Add(client) => { - println!("[Client Manager]: adding new client"); - Client::start(&client); - let mut lock = arc.clients.lock().unwrap(); - if lock.insert(client.uuid, client).is_none() { - println!("value is new"); - } - }, - Remove(uuid) => { - println!("[Client Manager]: removing client: {:?}", &uuid); - if let Some(client) = - arc.clients.lock().unwrap().remove(&uuid) - { - client.send_message(ClientMessage::Disconnect); - } - }, - SendMessage { to, from, content } => { - let lock = arc.clients.lock().unwrap(); - if let Some(client) = lock.get(&to) { - client.send_message(ClientMessage::Message { - from, - content, - }) - } - }, - SendClients {to} => { - let lock = arc.clients.lock().unwrap(); - if let Some(client) = lock.get(&to) { - let clients_vec: Vec> = lock.values().cloned().collect(); + println!("[Client manager]: recieved message: {:?}", message); - client.send_message(ClientMessage::Update { + match message { + Add(client) => { + println!("[Client Manager]: adding new client"); + client.start(); + let mut lock = client_manager.clients.lock().await; + if lock.insert(client.details.uuid, client).is_none() { + println!("value is new"); + } + } + Remove(uuid) => { + println!("[Client Manager]: removing client: {:?}", &uuid); + if let Some(client) = client_manager.clients.lock().await.remove(&uuid) { + client.send_message(ClientMessage::Disconnect).await; + } + } + SendMessage { to, from, content } => { + client_manager + .send_to_client(&to, ClientMessage::Message { from, content }) + .await; + } + SendClients { to } => { + let lock = client_manager.clients.lock().await; + if let Some(client) = lock.get(&to) { + let clients_vec: Vec> = lock.values().cloned().collect(); + + client + .send_message(ClientMessage::SendClients { clients: clients_vec, }) - } - }, - - - #[allow(unreachable_patterns)] - _ => println!("[Client manager]: not implemented"), + .await + } } + SendError { to } => { + let lock = client_manager.clients.lock().await; + if let Some(client) = lock.get(&to) { + client.send_message(ClientMessage::Error).await + } + } + #[allow(unreachable_patterns)] + _ => println!("[Client manager]: not implemented"), } } + }); + } + + async fn send_to_client(self: &Arc, id: &Uuid, msg: ClientMessage) { + let lock = self.clients.lock().await; + if let Some(client) = lock.get(id) { + client.clone().send_message(msg).await; } } - fn start(arc: &Arc) { - let arc = arc.clone(); - std::thread::spawn(move || ClientManager::run(&arc)); + pub async fn send_message(self: Arc, message: ClientMgrMessage) { + let _ = self.tx.send(message).await; } } diff --git a/server/src/main.rs b/server/src/main.rs index dfc409f..b6b69b4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,29 +1,36 @@ pub mod client; pub mod client_manager; -pub mod messages; pub mod network_manager; +pub mod prelude; pub mod server; +pub mod messages; + +use std::io; use clap::{App, Arg}; -use foundation::prelude::IPreemptive; use server::Server; -fn main() { +#[tokio::main] +async fn main() -> io::Result<()> { let _args = App::new("--rust chat server--") - .version("0.1.5") - .author("Mitchel Hardie , Michael Bailey ") - .about("this is a chat server developed in rust, depending on the version one of two implementations will be used") - .arg( - Arg::with_name("config") - .short("p") - .long("port") - .value_name("PORT") - .help("sets the port the server runs on.") - .takes_value(true)) - .get_matches(); + .version("0.1.5") + .author("Mitchel Hardie , Michael Bailey ") + .about( + "this is a chat server developed in rust, depending on the version one of two implementations will be used", + ) + .arg( + Arg::with_name("config") + .short("p") + .long("port") + .value_name("PORT") + .help("sets the port the server runs on.") + .takes_value(true), + ) + .get_matches(); - let server = Server::new(); + let server = Server::new().unwrap(); - Server::run(&server); + server.start().await; + Ok(()) } diff --git a/server/src/messages.rs b/server/src/messages.rs index f5d2e11..b9306c2 100644 --- a/server/src/messages.rs +++ b/server/src/messages.rs @@ -7,31 +7,47 @@ use crate::client::Client; pub enum ClientMessage { Message { from: Uuid, content: String }, - Update {clients: Vec>}, + SendClients { clients: Vec> }, Disconnect, + + Error, } #[derive(Debug)] pub enum ClientMgrMessage { Remove(Uuid), Add(Arc), - SendClients {to: Uuid}, + SendClients { + to: Uuid, + }, SendMessage { from: Uuid, to: Uuid, content: String, }, + SendError { + to: Uuid, + }, } #[derive(Debug)] pub enum ServerMessage { - ClientConnected(Arc), + ClientConnected { + client: Arc, + }, ClientSendMessage { from: Uuid, to: Uuid, content: String, }, - ClientDisconnected(Uuid), - ClientUpdate(Uuid), + ClientDisconnected { + id: Uuid, + }, + ClientUpdate { + to: Uuid, + }, + ClientError { + to: Uuid, + }, } diff --git a/server/src/network/mod.rs b/server/src/network/mod.rs new file mode 100644 index 0000000..b6e420f --- /dev/null +++ b/server/src/network/mod.rs @@ -0,0 +1,188 @@ +use std::fmt::Debug; +use std::io::Error; +use std::io::Write; +use std::sync::Arc; + +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use serde::Serialize; +use tokio::io::split; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::ReadHalf; +use tokio::io::WriteHalf; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::TcpStream; +use tokio::sync::Mutex; + +use crate::prelude::StreamMessageSender; +use crate::prelude::TransformerFn; + +type TransformerVec = Vec; + +pub struct SocketHandler { + stream_tx: Mutex>, + stream_rx: Mutex>>, + + send_transformer: Mutex, + recv_transformer: Mutex, +} + +impl SocketHandler { + pub fn new(connection: TcpStream) -> Arc { + let (rd, wd) = split(connection); + let reader = BufReader::new(rd); + + Arc::new(SocketHandler { + stream_tx: Mutex::new(wd), + stream_rx: Mutex::new(reader), + + send_transformer: Mutex::new(Vec::new()), + recv_transformer: Mutex::new(Vec::new()), + }) + } + + pub async fn push_layer(self: &Arc, send_func: TransformerFn, recv_func: TransformerFn) { + let mut send_lock = self.send_transformer.lock().await; + let mut recv_lock = self.recv_transformer.lock().await; + send_lock.push(send_func); + recv_lock.push(recv_func); + } + + pub async fn pop_layer(self: &Arc) { + let mut send_lock = self.send_transformer.lock().await; + let mut recv_lock = self.recv_transformer.lock().await; + let _ = send_lock.pop(); + let _ = recv_lock.pop(); + } +} + +#[async_trait] +impl StreamMessageSender for SocketHandler { + async fn send( + self: &Arc, + message: TOutMessage, + ) -> Result<(), Error> { + let mut out_buffer: Vec = Vec::new(); + let message_string = serde_json::to_string(&message)?; + writeln!(out_buffer, "{}", message_string)?; + + println!("[SocketHandler:send] message_before: {:?}", &out_buffer); + + let transformers = self.send_transformer.lock().await; + let iter = transformers.iter(); + + for func in iter { + out_buffer = (**func)(&out_buffer); + } + + println!("[SocketHandler:send] message_after: {:?}", &out_buffer); + + let mut lock = self.stream_tx.lock().await; + lock.write_all(&out_buffer).await?; + lock.flush().await?; + Ok(()) + } + + async fn recv<'de, TInMessage: DeserializeOwned + Send>( + self: &Arc, + ) -> Result { + let mut in_buffer = String::new(); + let mut lock = self.stream_rx.lock().await; + lock.read_line(&mut in_buffer).await?; + + println!("[SocketHandler:recv] message_before: {:?}", &in_buffer); + + let transformers = self.recv_transformer.lock().await; + let iter = transformers.iter(); + + let mut in_buffer = in_buffer.into_bytes(); + + for func in iter { + in_buffer = (**func)(&in_buffer); + } + + println!("[SocketHandler:recv] message_after: {:?}", &in_buffer); + + let in_buffer = String::from_utf8(in_buffer).expect("invalid utf_8"); + + let message: TInMessage = serde_json::from_str(&in_buffer).unwrap(); + + Ok(message) + } +} + +impl Debug for SocketHandler { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "[SocketSender]") + } +} + +#[cfg(test)] +mod test { + use tokio::runtime::Runtime; + use std::sync::Once; + use std::time::Duration; + + use tokio::task; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::time::sleep; + + use super::SocketHandler; + use crate::helpers::start_server; + use crate::helpers::create_test_shared; + use crate::prelude::StreamMessageSender; + use crate::encryption::create_encryption_transformers; + + + static SERVER_INIT: Once = Once::new(); + + fn setup() { + SERVER_INIT.call_once(|| { + std::thread::spawn(|| { + let rt = Runtime::new().unwrap(); + rt.block_on(start_server()) + + }); + }) + } + + #[tokio::test] + async fn test_socket_sender() { + setup(); + task::spawn(start_server()); + + let socket = TcpStream::connect("localhost:5600") + .await + .expect("failed to connect"); + + let handle = SocketHandler::new(socket); + let _ = handle.send::(true).await; + let message = handle.recv::().await.unwrap(); + + assert!(message); + } + + #[tokio::test] + async fn test_socket_sender_with_encryption() { + setup(); + task::spawn(start_server()); + + let socket = TcpStream::connect("localhost:5600") + .await + .unwrap(); + + let shared = create_test_shared(); + let (en, de) = create_encryption_transformers(shared, b"12345678901234567890123456789011"); + let handle = SocketHandler::new(socket); + + handle.push_layer(en, de).await; + + let _ = handle.send::(true).await; + let message = handle.recv::().await.unwrap(); + + assert!(message); + } +} diff --git a/server/src/network_manager.rs b/server/src/network_manager.rs index 5e450ac..7329656 100644 --- a/server/src/network_manager.rs +++ b/server/src/network_manager.rs @@ -1,131 +1,78 @@ -use foundation::prelude::IPreemptive; -use std::io::BufRead; -use std::io::BufReader; -use std::io::BufWriter; -use std::io::Write; -use std::net::TcpListener; use std::sync::Arc; -use std::thread; -use crossbeam_channel::Sender; +use tokio::net::TcpListener; +use tokio::sync::mpsc::Sender; + +use foundation::prelude::StreamMessageSender; +use foundation::messages::network::{NetworkSockIn, NetworkSockOut}; +use foundation::network::SocketHandler; use crate::client::Client; use crate::messages::ServerMessage; -use foundation::messages::network::{NetworkSockIn, NetworkSockOut}; pub struct NetworkManager { - listener: TcpListener, + address: String, server_channel: Sender, } impl NetworkManager { - pub fn new( - port: String, - server_channel: Sender, - ) -> Arc { - let mut address = "0.0.0.0:".to_string(); - address.push_str(&port); - - let listener = TcpListener::bind(address).expect("Could not bind to address"); - + pub fn new(_port: String, server_channel: Sender) -> Arc { Arc::new(NetworkManager { - listener, + address: "0.0.0.0:5600".to_string(), server_channel, }) } -} -impl IPreemptive for NetworkManager { - fn run(_: &Arc) {} + pub fn start(self: &Arc) { + let network_manager = self.clone(); - fn start(arc: &Arc) { - let arc = arc.clone(); - std::thread::spawn(move || { - // fetch new connections and add them to the client queue - for connection in arc.listener.incoming() { - println!("[NetworkManager]: New Connection!"); - match connection { - Ok(stream) => { - let server_channel = arc.server_channel.clone(); + tokio::spawn(async move { + let listener = TcpListener::bind(network_manager.address.clone()) + .await + .unwrap(); - // create readers - let mut reader = BufReader::new(stream.try_clone().unwrap()); - let mut writer = BufWriter::new(stream.try_clone().unwrap()); + loop { + let (connection, _) = listener.accept().await.unwrap(); + let stream_sender = SocketHandler::new(connection); + let server_channel = network_manager.server_channel.clone(); - let _handle = thread::Builder::new() - .name("NetworkJoinThread".to_string()) - .spawn(move || { - let mut out_buffer: Vec = Vec::new(); - let mut in_buffer: String = String::new(); + tokio::spawn(async move { + stream_sender + .send::(NetworkSockOut::Request) + .await + .expect("failed to send message"); - // send request message to connection - - let _ = writeln!( - out_buffer, - "{}", - serde_json::to_string(&NetworkSockOut::Request) - .unwrap() + if let Ok(request) = stream_sender.recv::().await { + match request { + NetworkSockIn::Info => { + stream_sender + .send(NetworkSockOut::GotInfo { + server_name: "oof", + server_owner: "michael", + }) + .await + .expect("failed to send got info"); + } + NetworkSockIn::Connect { + uuid, + username, + address, + } => { + // create client and send to server + let new_client = Client::new( + uuid, + username, + address, + stream_sender, + server_channel.clone(), ); - - let _ = writer.write_all(&out_buffer); - let _ = writer.flush(); - - // try get response - let res = reader.read_line(&mut in_buffer); - if res.is_err() { - return; - } - - //match the response - if let Ok(request) = - serde_json::from_str::(&in_buffer) - { - match request { - NetworkSockIn::Info => { - // send back server info to the connection - writer - .write_all( - serde_json::to_string( - &NetworkSockOut::GotInfo { - server_name: "oof", - server_owner: "michael", - }, - ) - .unwrap() - .as_bytes(), - ) - .unwrap(); - writer.write_all(b"\n").unwrap(); - writer.flush().unwrap(); - } - NetworkSockIn::Connect { - uuid, - username, - address, - } => { - // create client and send to server - let new_client = Client::new( - uuid, - username, - address, - stream.try_clone().unwrap(), - server_channel.clone(), - ); - server_channel - .send(ServerMessage::ClientConnected( - new_client, - )) - .unwrap_or_default(); - } - } - } - }); + let _ = server_channel + .send(ServerMessage::ClientConnected { client: new_client }) + .await; + } + } } - Err(e) => { - println!("[Network manager]: error getting stream: {:?}", e); - continue; - } - } + }); } }); } diff --git a/server/src/prelude.rs b/server/src/prelude.rs new file mode 100644 index 0000000..1242c64 --- /dev/null +++ b/server/src/prelude.rs @@ -0,0 +1,19 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use serde::de::DeserializeOwned; +use serde::Serialize; + +#[async_trait] +pub trait StreamMessageSender { + async fn send( + self: &Arc, + message: TOutMessage, + ) -> Result<(), std::io::Error>; + async fn recv<'de, TInMessage: DeserializeOwned + Send>( + self: &Arc, + ) -> Result; +} + +pub type TransformerFn = Box Vec + Send + Sync>; diff --git a/server/src/server.rs b/server/src/server.rs index 2e7d7ec..cb26327 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,15 +1,14 @@ use std::sync::Arc; -use crossbeam_channel::{unbounded, Receiver}; +// use crossbeam_channel::{unbounded, Receiver}; +use futures::lock::Mutex; +use tokio::sync::mpsc::{channel, Receiver}; use uuid::Uuid; use crate::client_manager::ClientManager; use crate::messages::ClientMgrMessage; use crate::messages::ServerMessage; use crate::network_manager::NetworkManager; -use foundation::prelude::ICooperative; -use foundation::prelude::IMessagable; -use foundation::prelude::IPreemptive; /// # ServerMessages /// This is used internally to send messages to the server to be dispatched @@ -19,65 +18,79 @@ pub enum ServerMessages { ClientDisconnected(Uuid), } +/// # Server +/// authors: @michael-bailey, @Mitch161 +/// This Represents a server instance. +/// it is componsed of a client manager and a network manager +/// pub struct Server { client_manager: Arc, network_manager: Arc, - - receiver: Receiver, + receiver: Mutex>, } impl Server { - pub fn new() -> Arc { - let (sender, receiver) = unbounded(); + /// Create a new server object + pub fn new() -> Result, Box> { + let (sender, receiver) = channel(1024); - Arc::new(Server { + Ok(Arc::new(Server { client_manager: ClientManager::new(sender.clone()), - network_manager: NetworkManager::new("5600".to_string(), sender), - receiver, - }) + receiver: Mutex::new(receiver), + })) } -} -impl ICooperative for Server { - fn tick(&self) { + pub async fn start(self: &Arc) { + // start client manager and network manager + self.network_manager.clone().start(); + self.client_manager.clone().start(); + + // clone block items + let server = self.clone(); + use ClientMgrMessage::{Add, Remove, SendMessage}; - // handle new messages loop - if !self.receiver.is_empty() { - for message in self.receiver.try_iter() { + loop { + let mut lock = server.receiver.lock().await; + if let Some(message) = lock.recv().await { println!("[server]: received message {:?}", &message); + match message { - ServerMessage::ClientConnected(client) => { - self.client_manager.send_message(Add(client)) + ServerMessage::ClientConnected { client } => { + server + .client_manager + .clone() + .send_message(Add(client)) + .await } - ServerMessage::ClientDisconnected(uuid) => { - println!("disconnecting client {:?}", uuid); - self.client_manager.send_message(Remove(uuid)); + ServerMessage::ClientDisconnected { id } => { + println!("disconnecting client {:?}", id); + server.client_manager.clone().send_message(Remove(id)).await; + } + ServerMessage::ClientSendMessage { from, to, content } => { + server + .client_manager + .clone() + .send_message(SendMessage { from, to, content }) + .await + } + ServerMessage::ClientUpdate { to } => { + server + .client_manager + .clone() + .send_message(ClientMgrMessage::SendClients { to }) + .await + } + ServerMessage::ClientError { to } => { + server + .client_manager + .clone() + .send_message(ClientMgrMessage::SendError { to }) + .await } - ServerMessage::ClientSendMessage { from, to, content } => self - .client_manager - .send_message(SendMessage { from, to, content }), - ServerMessage::ClientUpdate (_uuid) => println!("not implemented"), } } } } } - -impl IPreemptive for Server { - fn run(arc: &std::sync::Arc) { - // start services - NetworkManager::start(&arc.network_manager); - ClientManager::start(&arc.client_manager); - loop { - arc.tick(); - } - } - - fn start(arc: &std::sync::Arc) { - let arc = arc.clone(); - // start thread - std::thread::spawn(move || Server::run(&arc)); - } -}