diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..8ed633e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,62 @@ +use crate::*; +use core::fmt; +use std::sync::Arc; +use tokio::sync::mpsc::Sender; +use tokio::sync::Mutex; + +#[derive(Debug)] +pub struct ClientInner { + id: ClientUid, + writer: OwnedWriteHalf, + dead: Option>, +} + +impl Drop for ClientInner { + fn drop(&mut self) { + debug!("client {} disconnected", self.id); + let dead_notif = self.dead.take().unwrap(); + tokio::spawn(async move { + dead_notif.send(()).await.unwrap(); + }); + } +} + +#[derive(Debug, Clone)] +pub struct Client { + inner: Arc>, +} + +impl Client { + pub fn new(id: usize, write: OwnedWriteHalf, close_tx: Sender<()>) -> Self { + Self { + inner: Arc::new(Mutex::new(ClientInner { + id: ClientUid::new(id), + writer: write, + dead: Some(close_tx), + })), + } + } + + pub async fn write(&self, s: Msg) -> Result<(), CmdErr> { + let mut l = self.inner.lock().await; + l.writer + .write(format!("{s}\r\n").as_bytes()) + .await + .map_err(|e| { + eprintln!("{e}"); + CmdErr::WriteFailed + }) + .map(|_| ()) + } + + pub async fn ok(&self) { + self.write("+OK").await.unwrap(); + } + + pub async fn id(&self) -> usize { + *self.inner.lock().await.id + } + pub async fn uid(&self) -> ClientUid { + self.inner.lock().await.id + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..bd91378 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,17 @@ +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum CmdErr { + #[error("Unknown Protocol Operation")] + UnknownCmd, + #[error("Unknown subject selector token")] + BadToken, + #[error("Missing argument")] + MissingArg, + #[error("Missing subject")] + MissingSubject, + #[error("Missing reply to")] + MissingReplyTo, + #[error("Missing payload length")] + ExpectLen, + #[error("Unable to write to client")] + WriteFailed, +} diff --git a/src/input.rs b/src/input.rs new file mode 100644 index 0000000..5ef7007 --- /dev/null +++ b/src/input.rs @@ -0,0 +1,150 @@ +use crate::*; + +pub trait ParseCommand +where + Self: Sized, +{ + async fn parse_command( + s: &str, + buf: &mut Vec, + rx: &mut OwnedReadHalf, + ) -> Result; +} + +#[derive(Debug)] +pub struct ConnectArgs {} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] +pub struct Subscribe { + pub subject: SubjectMatcher, + pub queue_group: Option, + pub sub_uid: SubUid, +} + +impl ParseCommand for Subscribe { + #[tracing::instrument(skip(_buf, _rx))] + async fn parse_command( + s: &str, + _buf: &mut Vec, + _rx: &mut OwnedReadHalf, + ) -> Result { + let has_queue_group = s.split_whitespace().count() == 4; + let mut parts = s.split_whitespace(); + let _cmd = parts.next().ok_or(CmdErr::UnknownCmd)?; + debug!("parse SUB"); + + let subject = + SubjectMatcher::compile(parts.next().ok_or(CmdErr::MissingSubject)?.to_owned())?; + debug!("MSG channel is {subject:?}"); + + let queue_group = if has_queue_group { + Some(QueueGroup::new( + parts.next().ok_or(CmdErr::MissingSubject)?.to_owned(), + )) + } else { + None + }; + + let subscription_uid = SubUid::new(parts.next().ok_or(CmdErr::ExpectLen)?.to_owned()); + Ok(Self { + subject, + sub_uid: subscription_uid, + queue_group, + }) + } +} + +#[derive(Debug)] +pub struct UnSubscribe { + pub subscription_uid: SubUid, +} + +#[derive(Debug)] +pub struct Publish { + pub subject: ChannelName, + pub reply_to: Option, + pub payload: Payload, +} + +impl ParseCommand for Publish { + #[tracing::instrument(skip(buf, rx))] + async fn parse_command( + s: &str, + buf: &mut Vec, + rx: &mut OwnedReadHalf, + ) -> Result { + let has_reply_to = s + .split_whitespace() + .position(|s| s.parse::().is_ok()) + .unwrap_or(0) + == 3; + let mut parts = s.split_whitespace(); + + let _cmd = parts.next().ok_or(CmdErr::UnknownCmd)?; + let subject = ChannelName::new(parts.next().ok_or(CmdErr::MissingSubject)?.to_owned()); + let reply_to = if has_reply_to { + Some(ReplyTo::new( + parts.next().ok_or(CmdErr::MissingReplyTo)?.to_owned(), + )) + } else { + None + }; + + let len: usize = parts + .next() + .ok_or(CmdErr::ExpectLen)? + .parse() + .map_err(|_| CmdErr::ExpectLen)?; + let Ok(_) = rx.read_exact(&mut buf[..len]).await else { + return Err(CmdErr::MissingArg); + }; + let Ok(payload) = std::str::from_utf8(&buf[..len]) else { + return Err(CmdErr::MissingArg); + }; + Ok(Self { + subject, + payload: Payload::new(payload.to_owned()), + reply_to, + }) + } +} + +#[derive(Debug)] +pub enum Cmd { + Connect(ConnectArgs), + Ping, + Pong, + Subscribe(Subscribe), + UnSubscribe(UnSubscribe), + Publish(Publish), +} + +impl ParseCommand for Cmd { + #[tracing::instrument(skip(buf, rx))] + async fn parse_command( + s: &str, + buf: &mut Vec, + rx: &mut OwnedReadHalf, + ) -> Result { + debug!("received string {s:?} as a command"); + let mut parts = s.split_whitespace(); + let cmd = parts.next().ok_or(CmdErr::UnknownCmd)?; + if cmd.eq_ignore_ascii_case("connect") { + let args = ConnectArgs {}; + // TODO: args??? + Ok(Cmd::Connect(args)) + } else if cmd.eq_ignore_ascii_case("ping") { + Ok(Cmd::Ping) + } else if cmd.eq_ignore_ascii_case("pong") { + Ok(Cmd::Pong) + } else if cmd.eq_ignore_ascii_case("sub") { + let s = Subscribe::parse_command(s, buf, rx).await?; + Ok(Cmd::Subscribe(s)) + } else if cmd.eq_ignore_ascii_case("pub") { + let p = Publish::parse_command(s, buf, rx).await?; + Ok(Cmd::Publish(p)) + } else { + Err(CmdErr::UnknownCmd) + } + } +} diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..60f7d0d --- /dev/null +++ b/src/model.rs @@ -0,0 +1,224 @@ +use derive_more::{Constructor, Deref, Display}; + +use crate::CmdErr; + +#[derive(Debug, Display, Constructor, Deref)] +pub struct ChannelName(String); + +#[derive(Debug, Display, Constructor, Deref)] +pub struct ReplyTo(String); + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Constructor, Deref)] +pub struct QueueGroup(String); + +#[derive(Debug, Display, Constructor, Deref)] +pub struct Payload(String); + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Constructor, Deref)] +pub struct SubUid(String); + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Constructor, Deref)] +pub struct ClientUid(usize); + +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord, Display)] +pub enum SubjectToken { + Word(String), + AnyWord, + AnyTail, +} + +pub enum SubjectMatch { + Invalid, + Valid, + Fulfilled, +} + +impl SubjectToken { + fn is_match(&self, s: Option<&str>) -> SubjectMatch { + let Some(s) = s else { + return SubjectMatch::Invalid; + }; + match self { + Self::AnyTail => SubjectMatch::Fulfilled, + Self::AnyWord => SubjectMatch::Valid, + Self::Word(word) if word == s => SubjectMatch::Valid, + _ => SubjectMatch::Invalid, + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord, Deref)] +pub struct SubjectMatcher(Vec); + +impl SubjectMatcher { + pub fn compile>(s: S) -> Result { + let s: String = s.into(); + if s.split_whitespace().count() > 1 { + return Err(CmdErr::BadToken); + } + let len = s.split('.').count(); + let tokens: Vec<_> = + s.split('.') + .into_iter() + .try_fold(Vec::with_capacity(len), |mut v, s| { + let t = match s { + "*" => SubjectToken::AnyWord, + ">" => SubjectToken::AnyTail, + _ if s.contains('*') || s.contains('>') => return Err(CmdErr::BadToken), + _ => SubjectToken::Word(s.to_owned()), + }; + v.push(t); + Ok(v) + })?; + Ok(Self(tokens)) + } + + pub fn eq_with_chan(&self, chan: &str) -> bool { + let mut parts = chan.split('.'); + let len = self.0.len(); + self.0 + .iter() + .enumerate() + .find_map(|(idx, token)| match token.is_match(parts.next()) { + SubjectMatch::Fulfilled => Some(true), + SubjectMatch::Invalid => Some(false), + SubjectMatch::Valid if idx + 1 == len => Some(true), + SubjectMatch::Valid => None, + }) + .unwrap_or(false) + } +} + +#[cfg(test)] +mod sub_tests { + use std::sync::Once; + + use super::*; + use crate::*; + + static LOGGER: Once = Once::new(); + + #[test] + fn simple_word() { + LOGGER.call_once(init_logger); + + assert_eq!( + SubjectMatcher::compile("foo").unwrap().eq_with_chan("foo"), + true + ); + assert_eq!( + SubjectMatcher::compile("foo").unwrap().eq_with_chan("fo"), + false + ); + assert_eq!( + SubjectMatcher::compile("foo").unwrap().eq_with_chan("oo"), + false + ); + } + + #[test] + fn multiple_words() { + LOGGER.call_once(init_logger); + + assert_eq!( + SubjectMatcher::compile("foo.bar") + .unwrap() + .eq_with_chan("foo.bar"), + true + ); + assert_eq!( + SubjectMatcher::compile("foo.bar") + .unwrap() + .eq_with_chan("fo"), + false + ); + assert_eq!( + SubjectMatcher::compile("foo.bar") + .unwrap() + .eq_with_chan("oo"), + false + ); + } + + #[test] + fn space_at_selector() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("foo bar"), Err(CmdErr::BadToken)); + } + #[test] + fn tab_at_selector() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("foo\tbar"), Err(CmdErr::BadToken)); + } + #[test] + fn newline_at_selector() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("foo\nbar"), Err(CmdErr::BadToken)); + } + + #[test] + fn single_asterix() { + LOGGER.call_once(init_logger); + SubjectMatcher::compile("*").unwrap(); + SubjectMatcher::compile("foo.*").unwrap(); + SubjectMatcher::compile("foo.*.bar").unwrap(); + SubjectMatcher::compile("foo.*.bar.*").unwrap(); + } + #[test] + fn asterix_at_word_start() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("*oo"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.*oo"), Err(CmdErr::BadToken)); + } + #[test] + fn asterix_at_word_end() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("oo*"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.oo*"), Err(CmdErr::BadToken)); + } + #[test] + fn asterix_at_word_middle() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("o*o"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.o*o"), Err(CmdErr::BadToken)); + } + #[test] + fn multiple_asterix_at_word() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("a*b*c"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.a*b*c"), Err(CmdErr::BadToken)); + } + + #[test] + fn single_tail() { + LOGGER.call_once(init_logger); + SubjectMatcher::compile(">").unwrap(); + SubjectMatcher::compile("foo.>").unwrap(); + SubjectMatcher::compile("foo.>.bar").unwrap(); + SubjectMatcher::compile("foo.>.bar.>").unwrap(); + } + #[test] + fn tail_at_word_start() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile(">oo"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.>oo"), Err(CmdErr::BadToken)); + } + #[test] + fn tail_at_word_end() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("oo>"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.oo>"), Err(CmdErr::BadToken)); + } + #[test] + fn tail_at_word_middle() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("o>o"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.o>o"), Err(CmdErr::BadToken)); + } + #[test] + fn multiple_tail_at_word() { + LOGGER.call_once(init_logger); + assert_eq!(SubjectMatcher::compile("a>b>c"), Err(CmdErr::BadToken)); + assert_eq!(SubjectMatcher::compile("bar.a>b>c"), Err(CmdErr::BadToken)); + } +} diff --git a/src/output.rs b/src/output.rs new file mode 100644 index 0000000..5ded3be --- /dev/null +++ b/src/output.rs @@ -0,0 +1,59 @@ +use derive_more::Constructor; +use std::fmt; + +use crate::model::*; + +#[derive(serde::Serialize, Debug)] +pub struct ConnectionInfo { + pub server_id: String, + pub server_name: String, + pub version: &'static str, + pub proto: u16, + pub host: String, + pub port: u16, + pub client_ip: String, + pub client_id: usize, + pub max_payload: usize, +} + +impl fmt::Display for ConnectionInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(&serde_json::to_string(self).unwrap()) + } +} + +#[derive(Debug)] +pub struct Msg<'chan, 'reply_to, 'subscribe_id, 'payload> { + pub subject: &'chan ChannelName, + pub reply_to: &'reply_to Option, + pub sub_uid: &'subscribe_id SubUid, + pub payload: &'payload Payload, +} + +impl<'chan, 'reply_to, 'subscribe_id, 'payload> fmt::Display + for Msg<'chan, 'reply_to, 'subscribe_id, 'payload> +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(&format!("MSG {}", self.subject))?; + if let Some(reply_to) = &self.reply_to { + f.write_str(&format!(" {}", reply_to))?; + } + f.write_fmt(format_args!( + " {} {} {}", + self.sub_uid, + self.payload.len(), + self.payload + ))?; + Ok(()) + } +} + +#[derive(Debug, Constructor)] +pub struct Info(T); + +impl fmt::Display for Info { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str("INFO ")?; + fmt::Display::fmt(&self.0, f) + } +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..e0d12e2 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,116 @@ +use tokio::sync::mpsc::Sender; + +use crate::*; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +struct SubInfo { + subscribe: Subscribe, + client_id: ClientUid, +} + +#[derive(Debug, Clone, Default)] +pub struct Server { + clients: Arc>>, + subcriptions: Arc>>, + client_counter: Arc>, + pub addr: String, + pub port: u16, + pub max_payload: usize, +} + +impl Server { + pub fn new(addr: &str, port: u16, max_payload: usize) -> Self { + Self { + clients: Default::default(), + subcriptions: Default::default(), + client_counter: Default::default(), + addr: addr.into(), + port, + max_payload, + } + } +} + +impl Server { + pub async fn create_client(&self, stream: TcpStream) -> (OwnedReadHalf, Client, Sender<()>) { + let (reader, writer) = stream.into_split(); + let mut l = self.client_counter.lock().await; + *l += 1; + let id = *l; + let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(1); + + let client = Client::new(id, writer, tx.clone()); + + let subs = self.subcriptions.clone(); + tokio::spawn(async move { + if !rx.recv().await.is_some() { + return; + } + let mut l = subs.write().await; + l.retain(|SubInfo { client_id, .. }| **client_id != id); + }); + self.clients + .write() + .await + .insert(client.uid().await, client.clone()); + (reader, client, tx) + } + + pub async fn remove(&self, client: Client) { + self.clients.write().await.remove(&client.uid().await); + } + + pub async fn add_sub(&self, client_id: ClientUid, subscribe: Subscribe) { + let mut sub_lock = self.subcriptions.write().await; + sub_lock.insert(SubInfo { + subscribe, + client_id, + }); + } + + pub async fn rm_sub(&self, sub_uid: SubUid, sub: SubjectMatcher) { + let mut sub_lock = self.subcriptions.write().await; + sub_lock.retain(|SubInfo { subscribe, .. }| { + &sub != &subscribe.subject && sub_uid != subscribe.sub_uid + }); + } + + pub async fn publish(&self, publish: Publish) { + use futures::prelude::*; + + let sub_lock = self.subcriptions.read().await; + let cl_lock = self.clients.read().await; + futures::stream::iter( + sub_lock + .iter() + .filter(|SubInfo { subscribe, .. }| { + subscribe.subject.eq_with_chan(&*publish.subject) + }) + .filter_map( + |SubInfo { + subscribe, + client_id, + .. + }| cl_lock.get(client_id).zip(Some(&subscribe.sub_uid)), + ), + ) + .map(|(cl, sub_uid)| { + debug!("write response MSG for subject uid {sub_uid} and {publish:?}"); + cl.write(format!( + "{}", + Msg { + subject: &publish.subject, + reply_to: &publish.reply_to, + sub_uid, + payload: &publish.payload + } + )) + }) + .for_each_concurrent(None, |rx| async move { + if let Err(e) = rx.await { + tracing::error!("Failed to write to client: {e}"); + } + }) + .await; + } +} diff --git a/tests/build_and_boot.rs b/tests/build_and_boot.rs new file mode 100644 index 0000000..6d9e253 --- /dev/null +++ b/tests/build_and_boot.rs @@ -0,0 +1,35 @@ +pub async fn build_bin() { + std::process::Command::new("cargo") + .arg("build") + .arg("--bin") + .arg("nats-server") + .spawn() + .unwrap() + .wait() + .unwrap(); +} + +pub struct NatsServer(std::process::Child); + +impl Drop for NatsServer { + fn drop(&mut self) { + self.0.kill().unwrap(); + } +} + +pub async fn boot_bin() -> NatsServer { + std::process::Command::new("killall") + .arg("nats-server") + .spawn() + .unwrap() + .wait() + .unwrap(); + let child = std::process::Command::new("./target/debug/nats-server") + .env("RUST_LOG", "debug") + .spawn() + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(600)).await; + + NatsServer(child) +} diff --git a/tests/test_connect.rs b/tests/test_connect.rs new file mode 100644 index 0000000..c026c5a --- /dev/null +++ b/tests/test_connect.rs @@ -0,0 +1,81 @@ +use std::{str::*, time::Duration}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +include!("./build_and_boot.rs"); + +macro_rules! read_text { + ($r: expr, $buffer: expr) => {{ + println!("<- Reading response"); + let len = tokio::time::timeout(Duration::from_millis(100), $r.read($buffer)) + .await + .expect("timeout!") + .expect("failed to read"); + + let res = from_utf8(&$buffer[..len]) + .expect("not a string") + .to_string(); + $buffer.fill(0); + res + }}; +} +macro_rules! ex_txt { + ($r: expr, $buffer: expr, $txt: expr) => {{ + let res = read_text!($r, $buffer); + assert_eq!(res.as_str(), $txt); + $buffer.fill(0); + }}; +} +macro_rules! ex_ok { + ($r: expr, $buffer: expr) => { + ex_txt!($r, $buffer, "+OK\r\n") + }; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_connect() { + build_bin().await; + + println!("starting nats-server"); + let _server = boot_bin().await; + + println!("starting nats client"); + let client = tokio::net::TcpStream::connect(("0.0.0.0", 4222)) + .await + .unwrap(); + client.set_ttl(3).unwrap(); + let (mut r, mut w) = client.into_split(); + + let mut buffer = [0; 1024]; + read_text!(&mut r, &mut buffer); + + println!("-> Sending connect"); + w.write(b"connect {}\r\n").await.unwrap(); + ex_ok!(&mut r, &mut buffer); + println!(" OK"); + + println!("-> Sending SUB"); + w.write(b"sub foo.* 90\r\n").await.unwrap(); + ex_ok!(&mut r, &mut buffer); + println!(" OK"); + + println!("-> Sending PUB (1)"); + w.write(b"pub foo.bar 5\r\n").await.unwrap(); + w.write(b"world\r\n").await.unwrap(); + ex_ok!(&mut r, &mut buffer); + ex_txt!(&mut r, &mut buffer, "MSG foo.bar 90 5 world\r\n"); + println!(" OK"); + + println!("-> Sending PUB (2)"); + w.write(b"pub foo.bar 5\r\n").await.unwrap(); + w.write(b"world\r\n").await.unwrap(); + ex_ok!(&mut r, &mut buffer); + ex_txt!(&mut r, &mut buffer, "MSG foo.bar 90 5 world\r\n"); + println!(" OK"); + + println!("-> Sending PUB REPLY-TO"); + w.write(b"pub foo.bar home 5\r\n").await.unwrap(); + w.write(b"world\r\n").await.unwrap(); + ex_ok!(&mut r, &mut buffer); + ex_txt!(&mut r, &mut buffer, "MSG foo.bar home 90 5 world\r\n"); + println!(" OK"); +}