Working basic server

This commit is contained in:
eraden 2023-09-16 21:55:22 +02:00
parent 05e8baa6fb
commit 4084eb8df2
8 changed files with 744 additions and 0 deletions

62
src/client.rs Normal file
View File

@ -0,0 +1,62 @@
use crate::*;
use core::fmt;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
#[derive(Debug)]
pub struct ClientInner {
id: ClientUid,
writer: OwnedWriteHalf,
dead: Option<Sender<()>>,
}
impl Drop for ClientInner {
fn drop(&mut self) {
debug!("client {} disconnected", self.id);
let dead_notif = self.dead.take().unwrap();
tokio::spawn(async move {
dead_notif.send(()).await.unwrap();
});
}
}
#[derive(Debug, Clone)]
pub struct Client {
inner: Arc<Mutex<ClientInner>>,
}
impl Client {
pub fn new(id: usize, write: OwnedWriteHalf, close_tx: Sender<()>) -> Self {
Self {
inner: Arc::new(Mutex::new(ClientInner {
id: ClientUid::new(id),
writer: write,
dead: Some(close_tx),
})),
}
}
pub async fn write<Msg: fmt::Display>(&self, s: Msg) -> Result<(), CmdErr> {
let mut l = self.inner.lock().await;
l.writer
.write(format!("{s}\r\n").as_bytes())
.await
.map_err(|e| {
eprintln!("{e}");
CmdErr::WriteFailed
})
.map(|_| ())
}
pub async fn ok(&self) {
self.write("+OK").await.unwrap();
}
pub async fn id(&self) -> usize {
*self.inner.lock().await.id
}
pub async fn uid(&self) -> ClientUid {
self.inner.lock().await.id
}
}

17
src/error.rs Normal file
View File

@ -0,0 +1,17 @@
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum CmdErr {
#[error("Unknown Protocol Operation")]
UnknownCmd,
#[error("Unknown subject selector token")]
BadToken,
#[error("Missing argument")]
MissingArg,
#[error("Missing subject")]
MissingSubject,
#[error("Missing reply to")]
MissingReplyTo,
#[error("Missing payload length")]
ExpectLen,
#[error("Unable to write to client")]
WriteFailed,
}

150
src/input.rs Normal file
View File

@ -0,0 +1,150 @@
use crate::*;
pub trait ParseCommand
where
Self: Sized,
{
async fn parse_command(
s: &str,
buf: &mut Vec<u8>,
rx: &mut OwnedReadHalf,
) -> Result<Self, CmdErr>;
}
#[derive(Debug)]
pub struct ConnectArgs {}
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct Subscribe {
pub subject: SubjectMatcher,
pub queue_group: Option<QueueGroup>,
pub sub_uid: SubUid,
}
impl ParseCommand for Subscribe {
#[tracing::instrument(skip(_buf, _rx))]
async fn parse_command(
s: &str,
_buf: &mut Vec<u8>,
_rx: &mut OwnedReadHalf,
) -> Result<Self, CmdErr> {
let has_queue_group = s.split_whitespace().count() == 4;
let mut parts = s.split_whitespace();
let _cmd = parts.next().ok_or(CmdErr::UnknownCmd)?;
debug!("parse SUB");
let subject =
SubjectMatcher::compile(parts.next().ok_or(CmdErr::MissingSubject)?.to_owned())?;
debug!("MSG channel is {subject:?}");
let queue_group = if has_queue_group {
Some(QueueGroup::new(
parts.next().ok_or(CmdErr::MissingSubject)?.to_owned(),
))
} else {
None
};
let subscription_uid = SubUid::new(parts.next().ok_or(CmdErr::ExpectLen)?.to_owned());
Ok(Self {
subject,
sub_uid: subscription_uid,
queue_group,
})
}
}
#[derive(Debug)]
pub struct UnSubscribe {
pub subscription_uid: SubUid,
}
#[derive(Debug)]
pub struct Publish {
pub subject: ChannelName,
pub reply_to: Option<ReplyTo>,
pub payload: Payload,
}
impl ParseCommand for Publish {
#[tracing::instrument(skip(buf, rx))]
async fn parse_command(
s: &str,
buf: &mut Vec<u8>,
rx: &mut OwnedReadHalf,
) -> Result<Self, CmdErr> {
let has_reply_to = s
.split_whitespace()
.position(|s| s.parse::<usize>().is_ok())
.unwrap_or(0)
== 3;
let mut parts = s.split_whitespace();
let _cmd = parts.next().ok_or(CmdErr::UnknownCmd)?;
let subject = ChannelName::new(parts.next().ok_or(CmdErr::MissingSubject)?.to_owned());
let reply_to = if has_reply_to {
Some(ReplyTo::new(
parts.next().ok_or(CmdErr::MissingReplyTo)?.to_owned(),
))
} else {
None
};
let len: usize = parts
.next()
.ok_or(CmdErr::ExpectLen)?
.parse()
.map_err(|_| CmdErr::ExpectLen)?;
let Ok(_) = rx.read_exact(&mut buf[..len]).await else {
return Err(CmdErr::MissingArg);
};
let Ok(payload) = std::str::from_utf8(&buf[..len]) else {
return Err(CmdErr::MissingArg);
};
Ok(Self {
subject,
payload: Payload::new(payload.to_owned()),
reply_to,
})
}
}
#[derive(Debug)]
pub enum Cmd {
Connect(ConnectArgs),
Ping,
Pong,
Subscribe(Subscribe),
UnSubscribe(UnSubscribe),
Publish(Publish),
}
impl ParseCommand for Cmd {
#[tracing::instrument(skip(buf, rx))]
async fn parse_command(
s: &str,
buf: &mut Vec<u8>,
rx: &mut OwnedReadHalf,
) -> Result<Self, CmdErr> {
debug!("received string {s:?} as a command");
let mut parts = s.split_whitespace();
let cmd = parts.next().ok_or(CmdErr::UnknownCmd)?;
if cmd.eq_ignore_ascii_case("connect") {
let args = ConnectArgs {};
// TODO: args???
Ok(Cmd::Connect(args))
} else if cmd.eq_ignore_ascii_case("ping") {
Ok(Cmd::Ping)
} else if cmd.eq_ignore_ascii_case("pong") {
Ok(Cmd::Pong)
} else if cmd.eq_ignore_ascii_case("sub") {
let s = Subscribe::parse_command(s, buf, rx).await?;
Ok(Cmd::Subscribe(s))
} else if cmd.eq_ignore_ascii_case("pub") {
let p = Publish::parse_command(s, buf, rx).await?;
Ok(Cmd::Publish(p))
} else {
Err(CmdErr::UnknownCmd)
}
}
}

224
src/model.rs Normal file
View File

@ -0,0 +1,224 @@
use derive_more::{Constructor, Deref, Display};
use crate::CmdErr;
#[derive(Debug, Display, Constructor, Deref)]
pub struct ChannelName(String);
#[derive(Debug, Display, Constructor, Deref)]
pub struct ReplyTo(String);
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Constructor, Deref)]
pub struct QueueGroup(String);
#[derive(Debug, Display, Constructor, Deref)]
pub struct Payload(String);
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Constructor, Deref)]
pub struct SubUid(String);
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Display, Constructor, Deref)]
pub struct ClientUid(usize);
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord, Display)]
pub enum SubjectToken {
Word(String),
AnyWord,
AnyTail,
}
pub enum SubjectMatch {
Invalid,
Valid,
Fulfilled,
}
impl SubjectToken {
fn is_match(&self, s: Option<&str>) -> SubjectMatch {
let Some(s) = s else {
return SubjectMatch::Invalid;
};
match self {
Self::AnyTail => SubjectMatch::Fulfilled,
Self::AnyWord => SubjectMatch::Valid,
Self::Word(word) if word == s => SubjectMatch::Valid,
_ => SubjectMatch::Invalid,
}
}
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord, Deref)]
pub struct SubjectMatcher(Vec<SubjectToken>);
impl SubjectMatcher {
pub fn compile<S: Into<String>>(s: S) -> Result<Self, CmdErr> {
let s: String = s.into();
if s.split_whitespace().count() > 1 {
return Err(CmdErr::BadToken);
}
let len = s.split('.').count();
let tokens: Vec<_> =
s.split('.')
.into_iter()
.try_fold(Vec::with_capacity(len), |mut v, s| {
let t = match s {
"*" => SubjectToken::AnyWord,
">" => SubjectToken::AnyTail,
_ if s.contains('*') || s.contains('>') => return Err(CmdErr::BadToken),
_ => SubjectToken::Word(s.to_owned()),
};
v.push(t);
Ok(v)
})?;
Ok(Self(tokens))
}
pub fn eq_with_chan(&self, chan: &str) -> bool {
let mut parts = chan.split('.');
let len = self.0.len();
self.0
.iter()
.enumerate()
.find_map(|(idx, token)| match token.is_match(parts.next()) {
SubjectMatch::Fulfilled => Some(true),
SubjectMatch::Invalid => Some(false),
SubjectMatch::Valid if idx + 1 == len => Some(true),
SubjectMatch::Valid => None,
})
.unwrap_or(false)
}
}
#[cfg(test)]
mod sub_tests {
use std::sync::Once;
use super::*;
use crate::*;
static LOGGER: Once = Once::new();
#[test]
fn simple_word() {
LOGGER.call_once(init_logger);
assert_eq!(
SubjectMatcher::compile("foo").unwrap().eq_with_chan("foo"),
true
);
assert_eq!(
SubjectMatcher::compile("foo").unwrap().eq_with_chan("fo"),
false
);
assert_eq!(
SubjectMatcher::compile("foo").unwrap().eq_with_chan("oo"),
false
);
}
#[test]
fn multiple_words() {
LOGGER.call_once(init_logger);
assert_eq!(
SubjectMatcher::compile("foo.bar")
.unwrap()
.eq_with_chan("foo.bar"),
true
);
assert_eq!(
SubjectMatcher::compile("foo.bar")
.unwrap()
.eq_with_chan("fo"),
false
);
assert_eq!(
SubjectMatcher::compile("foo.bar")
.unwrap()
.eq_with_chan("oo"),
false
);
}
#[test]
fn space_at_selector() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("foo bar"), Err(CmdErr::BadToken));
}
#[test]
fn tab_at_selector() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("foo\tbar"), Err(CmdErr::BadToken));
}
#[test]
fn newline_at_selector() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("foo\nbar"), Err(CmdErr::BadToken));
}
#[test]
fn single_asterix() {
LOGGER.call_once(init_logger);
SubjectMatcher::compile("*").unwrap();
SubjectMatcher::compile("foo.*").unwrap();
SubjectMatcher::compile("foo.*.bar").unwrap();
SubjectMatcher::compile("foo.*.bar.*").unwrap();
}
#[test]
fn asterix_at_word_start() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("*oo"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.*oo"), Err(CmdErr::BadToken));
}
#[test]
fn asterix_at_word_end() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("oo*"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.oo*"), Err(CmdErr::BadToken));
}
#[test]
fn asterix_at_word_middle() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("o*o"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.o*o"), Err(CmdErr::BadToken));
}
#[test]
fn multiple_asterix_at_word() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("a*b*c"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.a*b*c"), Err(CmdErr::BadToken));
}
#[test]
fn single_tail() {
LOGGER.call_once(init_logger);
SubjectMatcher::compile(">").unwrap();
SubjectMatcher::compile("foo.>").unwrap();
SubjectMatcher::compile("foo.>.bar").unwrap();
SubjectMatcher::compile("foo.>.bar.>").unwrap();
}
#[test]
fn tail_at_word_start() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile(">oo"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.>oo"), Err(CmdErr::BadToken));
}
#[test]
fn tail_at_word_end() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("oo>"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.oo>"), Err(CmdErr::BadToken));
}
#[test]
fn tail_at_word_middle() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("o>o"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.o>o"), Err(CmdErr::BadToken));
}
#[test]
fn multiple_tail_at_word() {
LOGGER.call_once(init_logger);
assert_eq!(SubjectMatcher::compile("a>b>c"), Err(CmdErr::BadToken));
assert_eq!(SubjectMatcher::compile("bar.a>b>c"), Err(CmdErr::BadToken));
}
}

59
src/output.rs Normal file
View File

@ -0,0 +1,59 @@
use derive_more::Constructor;
use std::fmt;
use crate::model::*;
#[derive(serde::Serialize, Debug)]
pub struct ConnectionInfo {
pub server_id: String,
pub server_name: String,
pub version: &'static str,
pub proto: u16,
pub host: String,
pub port: u16,
pub client_ip: String,
pub client_id: usize,
pub max_payload: usize,
}
impl fmt::Display for ConnectionInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.write_str(&serde_json::to_string(self).unwrap())
}
}
#[derive(Debug)]
pub struct Msg<'chan, 'reply_to, 'subscribe_id, 'payload> {
pub subject: &'chan ChannelName,
pub reply_to: &'reply_to Option<ReplyTo>,
pub sub_uid: &'subscribe_id SubUid,
pub payload: &'payload Payload,
}
impl<'chan, 'reply_to, 'subscribe_id, 'payload> fmt::Display
for Msg<'chan, 'reply_to, 'subscribe_id, 'payload>
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.write_str(&format!("MSG {}", self.subject))?;
if let Some(reply_to) = &self.reply_to {
f.write_str(&format!(" {}", reply_to))?;
}
f.write_fmt(format_args!(
" {} {} {}",
self.sub_uid,
self.payload.len(),
self.payload
))?;
Ok(())
}
}
#[derive(Debug, Constructor)]
pub struct Info<T: fmt::Display + fmt::Debug>(T);
impl<T: fmt::Display + fmt::Debug> fmt::Display for Info<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.write_str("INFO ")?;
fmt::Display::fmt(&self.0, f)
}
}

116
src/server.rs Normal file
View File

@ -0,0 +1,116 @@
use tokio::sync::mpsc::Sender;
use crate::*;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct SubInfo {
subscribe: Subscribe,
client_id: ClientUid,
}
#[derive(Debug, Clone, Default)]
pub struct Server {
clients: Arc<RwLock<HashMap<ClientUid, Client>>>,
subcriptions: Arc<RwLock<BTreeSet<SubInfo>>>,
client_counter: Arc<Mutex<usize>>,
pub addr: String,
pub port: u16,
pub max_payload: usize,
}
impl Server {
pub fn new(addr: &str, port: u16, max_payload: usize) -> Self {
Self {
clients: Default::default(),
subcriptions: Default::default(),
client_counter: Default::default(),
addr: addr.into(),
port,
max_payload,
}
}
}
impl Server {
pub async fn create_client(&self, stream: TcpStream) -> (OwnedReadHalf, Client, Sender<()>) {
let (reader, writer) = stream.into_split();
let mut l = self.client_counter.lock().await;
*l += 1;
let id = *l;
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(1);
let client = Client::new(id, writer, tx.clone());
let subs = self.subcriptions.clone();
tokio::spawn(async move {
if !rx.recv().await.is_some() {
return;
}
let mut l = subs.write().await;
l.retain(|SubInfo { client_id, .. }| **client_id != id);
});
self.clients
.write()
.await
.insert(client.uid().await, client.clone());
(reader, client, tx)
}
pub async fn remove(&self, client: Client) {
self.clients.write().await.remove(&client.uid().await);
}
pub async fn add_sub(&self, client_id: ClientUid, subscribe: Subscribe) {
let mut sub_lock = self.subcriptions.write().await;
sub_lock.insert(SubInfo {
subscribe,
client_id,
});
}
pub async fn rm_sub(&self, sub_uid: SubUid, sub: SubjectMatcher) {
let mut sub_lock = self.subcriptions.write().await;
sub_lock.retain(|SubInfo { subscribe, .. }| {
&sub != &subscribe.subject && sub_uid != subscribe.sub_uid
});
}
pub async fn publish(&self, publish: Publish) {
use futures::prelude::*;
let sub_lock = self.subcriptions.read().await;
let cl_lock = self.clients.read().await;
futures::stream::iter(
sub_lock
.iter()
.filter(|SubInfo { subscribe, .. }| {
subscribe.subject.eq_with_chan(&*publish.subject)
})
.filter_map(
|SubInfo {
subscribe,
client_id,
..
}| cl_lock.get(client_id).zip(Some(&subscribe.sub_uid)),
),
)
.map(|(cl, sub_uid)| {
debug!("write response MSG for subject uid {sub_uid} and {publish:?}");
cl.write(format!(
"{}",
Msg {
subject: &publish.subject,
reply_to: &publish.reply_to,
sub_uid,
payload: &publish.payload
}
))
})
.for_each_concurrent(None, |rx| async move {
if let Err(e) = rx.await {
tracing::error!("Failed to write to client: {e}");
}
})
.await;
}
}

35
tests/build_and_boot.rs Normal file
View File

@ -0,0 +1,35 @@
pub async fn build_bin() {
std::process::Command::new("cargo")
.arg("build")
.arg("--bin")
.arg("nats-server")
.spawn()
.unwrap()
.wait()
.unwrap();
}
pub struct NatsServer(std::process::Child);
impl Drop for NatsServer {
fn drop(&mut self) {
self.0.kill().unwrap();
}
}
pub async fn boot_bin() -> NatsServer {
std::process::Command::new("killall")
.arg("nats-server")
.spawn()
.unwrap()
.wait()
.unwrap();
let child = std::process::Command::new("./target/debug/nats-server")
.env("RUST_LOG", "debug")
.spawn()
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(600)).await;
NatsServer(child)
}

81
tests/test_connect.rs Normal file
View File

@ -0,0 +1,81 @@
use std::{str::*, time::Duration};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
include!("./build_and_boot.rs");
macro_rules! read_text {
($r: expr, $buffer: expr) => {{
println!("<- Reading response");
let len = tokio::time::timeout(Duration::from_millis(100), $r.read($buffer))
.await
.expect("timeout!")
.expect("failed to read");
let res = from_utf8(&$buffer[..len])
.expect("not a string")
.to_string();
$buffer.fill(0);
res
}};
}
macro_rules! ex_txt {
($r: expr, $buffer: expr, $txt: expr) => {{
let res = read_text!($r, $buffer);
assert_eq!(res.as_str(), $txt);
$buffer.fill(0);
}};
}
macro_rules! ex_ok {
($r: expr, $buffer: expr) => {
ex_txt!($r, $buffer, "+OK\r\n")
};
}
#[tokio::test(flavor = "multi_thread")]
async fn test_connect() {
build_bin().await;
println!("starting nats-server");
let _server = boot_bin().await;
println!("starting nats client");
let client = tokio::net::TcpStream::connect(("0.0.0.0", 4222))
.await
.unwrap();
client.set_ttl(3).unwrap();
let (mut r, mut w) = client.into_split();
let mut buffer = [0; 1024];
read_text!(&mut r, &mut buffer);
println!("-> Sending connect");
w.write(b"connect {}\r\n").await.unwrap();
ex_ok!(&mut r, &mut buffer);
println!(" OK");
println!("-> Sending SUB");
w.write(b"sub foo.* 90\r\n").await.unwrap();
ex_ok!(&mut r, &mut buffer);
println!(" OK");
println!("-> Sending PUB (1)");
w.write(b"pub foo.bar 5\r\n").await.unwrap();
w.write(b"world\r\n").await.unwrap();
ex_ok!(&mut r, &mut buffer);
ex_txt!(&mut r, &mut buffer, "MSG foo.bar 90 5 world\r\n");
println!(" OK");
println!("-> Sending PUB (2)");
w.write(b"pub foo.bar 5\r\n").await.unwrap();
w.write(b"world\r\n").await.unwrap();
ex_ok!(&mut r, &mut buffer);
ex_txt!(&mut r, &mut buffer, "MSG foo.bar 90 5 world\r\n");
println!(" OK");
println!("-> Sending PUB REPLY-TO");
w.write(b"pub foo.bar home 5\r\n").await.unwrap();
w.write(b"world\r\n").await.unwrap();
ex_ok!(&mut r, &mut buffer);
ex_txt!(&mut r, &mut buffer, "MSG foo.bar home 90 5 world\r\n");
println!(" OK");
}