diff --git a/src/buffers.rs b/src/buffers.rs index 95a45c2..d468847 100644 --- a/src/buffers.rs +++ b/src/buffers.rs @@ -1,21 +1,34 @@ +use crate::CmdErr; +use tokio::net::tcp::OwnedReadHalf; + struct ReadState { - // rx: OwnedReadHalf, buffer: Vec, - cursor: usize, + read_cursor: usize, + write_cursor: usize, } impl ReadState { + fn new(len: usize) -> Self { + let buffer = vec![0; len]; + Self { + buffer, + read_cursor: 0, + write_cursor: 0, + } + } + fn has_next(&self) -> bool { - self.buffer[..self.cursor] + self.buffer[..self.read_cursor] .iter() .position(|v| *v == b'\n') .is_some() } - fn shift_ramaining(&mut self) { + + fn shift_remaining(&mut self) { if !self.has_next() { return; } - let pos = self.cursor; + let pos = self.read_cursor; let mut last_old = 0; let mut last_new = 0; for (new, old) in (pos..self.buffer.len()).enumerate() { @@ -33,32 +46,86 @@ impl ReadState { // eprintln!("erase {old}"); self.buffer[old] = 0; } - self.cursor = 0; + self.read_cursor = 0; + self.write_cursor = self.buffer.iter().position(|n| *n == 0).unwrap_or_default(); } - fn next(&mut self) -> Option<&[u8]> { - let Some(end) = self.buffer[self.cursor..self.buffer.len()] + + fn next<'s>(&'s mut self) -> Option<&'s [u8]> { + let Some(end) = self.buffer[self.read_cursor..self.buffer.len()] .iter() - .position(|v| *v == b'\n').map(|n| n + self.cursor) else { + .position(|v| *v == b'\n').map(|n| n + self.read_cursor) else { return None; }; // eprintln!("cursor {} end {end}", self.cursor); - let s = &self.buffer[self.cursor..=end]; - self.cursor = (end + 1).min(self.buffer.len() - 1); + let s = &self.buffer[self.read_cursor..=end]; + self.read_cursor = (end + 1).min(self.buffer.len()); Some(s) } + + fn write(&mut self, buffer: &[u8]) -> Result { + if self.write_cursor + buffer.len() > self.buffer.len() { + return Err(CmdErr::BufferFull); + } + let wc = self.write_cursor; + for (idx, b) in buffer.iter().copied().enumerate() { + self.buffer[idx + wc] = b; + } + self.write_cursor = wc + buffer.len(); + Ok(self.write_cursor) + } +} + +struct TcpReadState<'read, 'buffer> { + reader: &'read mut OwnedReadHalf, + buffer: &'buffer mut ReadState, +} + +impl<'read, 'buffer> TcpReadState<'read, 'buffer> { + pub fn new(reader: &'read mut OwnedReadHalf, buffer: &'buffer mut ReadState) -> Self { + Self { + reader, + buffer, + } + } + + pub async fn next(&mut self) -> Option<&'buffer [u8]> { + use tokio::io::AsyncReadExt; + use tracing::*; + { + if let Some(v) = self.buffer.next() { + return Some(v); + } + } + + { + match self.reader.read_buf(&mut self.buffer.buffer).await { + Ok(n) => { + debug!("received {n} bytes"); + } + Err(e) => { + debug!("Failed to read from rx: {e}"); + return None; + } + }; + } + { + self.buffer.next() + } + } } #[cfg(test)] -mod tests { +mod read_test { use super::*; #[test] fn shift_empty() { let mut state = ReadState { buffer: vec![0, 0, 0, 0, 0, 0, 0, 0], - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; - state.shift_ramaining(); + state.shift_remaining(); let expected = vec![0, 0, 0, 0, 0, 0, 0, 0]; assert_eq!( std::str::from_utf8(&state.buffer), @@ -69,9 +136,10 @@ mod tests { fn shift_single_unterminated() { let mut state = ReadState { buffer: vec![b'h', b'e', b'l', b'l', b'o', 0, 0, 0], - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; - state.shift_ramaining(); + state.shift_remaining(); let expected = vec![b'h', b'e', b'l', b'l', b'o', 0, 0, 0]; assert_eq!( std::str::from_utf8(&state.buffer), @@ -82,9 +150,10 @@ mod tests { fn shift_single_terminated() { let mut state = ReadState { buffer: vec![b'h', b'e', b'l', b'l', b'o', b'\r', b'\n', 0], - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; - state.shift_ramaining(); + state.shift_remaining(); let expected = vec![b'h', b'e', b'l', b'l', b'o', b'\r', b'\n', 0]; assert_eq!( std::str::from_utf8(&state.buffer), @@ -95,9 +164,10 @@ mod tests { fn shift_single_unterminated_moved() { let mut state = ReadState { buffer: b"hello\r\nworld\0\0\0\0".to_vec(), - cursor: 7, + read_cursor: 7, + write_cursor: 0, }; - state.shift_ramaining(); + state.shift_remaining(); let expected = b"world\0\0\0\0\0\0\0\0\0\0\0".to_vec(); assert_eq!( std::str::from_utf8(&state.buffer), @@ -108,9 +178,10 @@ mod tests { fn shift_single_terminated_moved() { let mut state = ReadState { buffer: b"hello\r\nworld\r\n\0\0\0\0\0\0\0\0".to_vec(), - cursor: 7, + read_cursor: 7, + write_cursor: 0, }; - state.shift_ramaining(); + state.shift_remaining(); let expected = b"world\r\n\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".to_vec(); assert_eq!( std::str::from_utf8(&state.buffer), @@ -122,7 +193,8 @@ mod tests { fn empty_next() { let mut state = ReadState { buffer: b"\0\0\0\0\0\0\0\0".to_vec(), - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; assert_eq!(state.next(), None); } @@ -130,7 +202,8 @@ mod tests { fn depleted_next() { let mut state = ReadState { buffer: b"hello\0\0\0\0\0\0\0\0".to_vec(), - cursor: 5, + read_cursor: 5, + write_cursor: 0, }; assert_eq!(state.next(), None); } @@ -138,35 +211,38 @@ mod tests { fn one_world_next() { let mut state = ReadState { buffer: b"hello\r\n\0\0\0\0\0\0\0\0".to_vec(), - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; assert_eq!(state.next(), Some(b"hello\r\n".as_slice())); - assert_eq!(state.cursor, 7); + assert_eq!(state.read_cursor, 7); } #[test] fn tree_world_next() { let mut state = ReadState { buffer: b"hello\r\nworld\r\nnats\r\n\0\0\0\0\0\0\0\0".to_vec(), - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; assert_eq!(state.next(), Some(b"hello\r\n".as_slice())); - assert_eq!(state.cursor, 7); + assert_eq!(state.read_cursor, 7); assert_eq!(state.next(), Some(b"world\r\n".as_slice())); - assert_eq!(state.cursor, 14); + assert_eq!(state.read_cursor, 14); assert_eq!(state.next(), Some(b"nats\r\n".as_slice())); - assert_eq!(state.cursor, 20); + assert_eq!(state.read_cursor, 20); assert_eq!(state.next(), None); - assert_eq!(state.cursor, 20); + assert_eq!(state.read_cursor, 20); } #[test] fn partial_move_and_shift() { let mut state = ReadState { buffer: b"hello\r\nworld\r\nnats\r\n\0\0\0\0\0\0\0\0".to_vec(), - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; state.next(); state.next(); - state.shift_ramaining(); + state.shift_remaining(); let expected = b"nats\r\n\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".to_vec(); assert_eq!( std::str::from_utf8(&state.buffer), @@ -177,12 +253,13 @@ mod tests { fn move_and_shift_non_terminated() { let mut state = ReadState { buffer: b"hello\r\nworld\r\nnats\0\0\0\0\0\0\0\0\0\0".to_vec(), - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; state.next(); state.next(); state.next(); - state.shift_ramaining(); + state.shift_remaining(); let expected = b"nats\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".to_vec(); assert_eq!( std::str::from_utf8(&state.buffer), @@ -193,13 +270,14 @@ mod tests { fn full_move_and_shift() { let mut state = ReadState { buffer: b"hello\r\nworld\r\nnats\r\n\0\0\0\0\0\0\0\0".to_vec(), - cursor: 0, + read_cursor: 0, + write_cursor: 0, }; state.next(); state.next(); state.next(); - assert_eq!(state.cursor, 20); - state.shift_ramaining(); + assert_eq!(state.read_cursor, 20); + state.shift_remaining(); let expected = b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".to_vec(); assert_eq!( std::str::from_utf8(&state.buffer), @@ -207,3 +285,38 @@ mod tests { ); } } + +#[cfg(test)] +mod write_buffer { + use super::*; + + #[test] + fn write_2_msg() { + let mut state = ReadState::new(6); + assert_eq!(state.write(b"ab"), Ok(2)); + assert_eq!(state.write(b"cd"), Ok(4)); + assert_eq!(state.buffer.as_slice(), b"abcd\0\0"); + } + + #[test] + fn write_too_much() { + let mut state = ReadState::new(6); + assert_eq!(state.write(b"ab"), Ok(2)); + assert_eq!(state.write(b"cd"), Ok(4)); + assert_eq!(state.write(b"efg"), Err(CmdErr::BufferFull)); + } + #[test] + fn write_after_shift() { + let mut state = ReadState::new(6); + assert_eq!(state.write(b"a\r\n"), Ok(3)); + assert_eq!(state.write(b"c\r\n"), Ok(6)); + while let Some(s) = state.next() { + eprintln!("next {s:?}"); + } + state.shift_remaining(); + assert_eq!(state.buffer.as_slice(), b"\0\0\0\0\0\0"); + let res = state.write(b"efg"); + assert_eq!(state.buffer.as_slice(), b"efg\0\0\0"); + assert_eq!(res, Ok(3)); + } +} diff --git a/src/error.rs b/src/error.rs index bd91378..318aa4d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,4 +14,6 @@ pub enum CmdErr { ExpectLen, #[error("Unable to write to client")] WriteFailed, + #[error("Server is overloaded and can't handle message")] + BufferFull, }