commit 9bfc723bcbc53d3b73ddd4ce2224c4a1076993b9 Author: eraden Date: Fri Jul 21 22:35:18 2023 +0200 Valid solution diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9ce42ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +/target + + +# Added by cargo +# +# already existing elements were commented out + +#/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0157a88 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "deduplicator" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1.72" +tokio = { version = "1.29.1", features = ["full"] } +reqwest = "0.11.18" +futures-channel = "0.3.28" + +[dev-dependencies] +httpmock = "0.6.8" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..2d8d1e4 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,267 @@ +#![allow(dead_code)] + +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use async_trait::async_trait; +use futures_channel::oneshot::{channel, Sender}; + +#[async_trait] +trait Deduplicator: Default + Send + Sync + 'static { + async fn fetch(&self, url: String) -> Result; +} + +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +struct Response { + status: u16, + body: String, +} + +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +struct Error { + status: Option, + message: String, +} + +type ReqRes = Result; +type ReqResQueue = Vec>; + +pub struct DeduplicatorClient { + pending: Arc>>, + client: reqwest::Client, +} + +impl Default for DeduplicatorClient { + fn default() -> Self { + Self { + pending: Arc::new(RwLock::new(HashMap::with_capacity(64))), + client: reqwest::ClientBuilder::default().build().unwrap(), + } + } +} + +#[async_trait] +impl Deduplicator for DeduplicatorClient { + async fn fetch(&self, url: String) -> ReqRes { + let is_pending = { + let lock = self.pending.read().unwrap(); + lock.contains_key(&url) + }; + if is_pending { + let rx = { + let mut lock = self.pending.write().unwrap(); + let (tx, rx) = channel(); + lock.get_mut(&url).unwrap().push(tx); + rx + }; + rx.await.unwrap() + } else { + start_new_req(url, self.client.clone(), self.pending.clone()).await + } + } +} + +async fn start_new_req( + url: String, + client: reqwest::Client, + pending: Arc>>, +) -> Result { + let mut queue = Vec::with_capacity(64); + let (tx, rx) = channel(); + queue.push(tx); + { + let mut lock = pending.write().unwrap(); + lock.insert(url.clone(), queue); + } + tokio::spawn(perform_request(url, client, pending)); + rx.await.unwrap() +} + +async fn perform_request( + url: String, + client: reqwest::Client, + pending: Arc>>, +) { + let res = client.get(&url).send().await; + let res = match res { + Ok(res) => { + if res.status().is_success() { + Ok(Response { + status: res.status().as_u16(), + body: res.text().await.unwrap_or_default(), + }) + } else { + Err(Error { + status: Some(res.status().as_u16()), + message: res.text().await.unwrap_or_default(), + }) + } + } + Err(e) => { + Err(Error { + status: e.status().map(|s| s.as_u16()), + message: e.to_string(), + }) + } + }; + + let queue = pending.write().unwrap().remove(&url).unwrap(); + for tx in queue { + tx.send(res.clone()).unwrap(); + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + use httpmock::prelude::*; + + #[tokio::test] + async fn single() { + let server = MockServer::start(); + + let hello_mock = server.mock(|when, then| { + when.method(GET) + .path("/translate") + .query_param("word", "hello"); + then.status(200) + .delay(Duration::from_millis(50)) + .header("content-type", "text/html") + .body("ohi"); + }); + + let client = DeduplicatorClient::default(); + let res = client.fetch(server.url("/translate?word=hello")).await; + + hello_mock.assert_hits(1); + assert_eq!( + res, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + } + + #[tokio::test] + async fn three() { + let server = MockServer::start(); + + let hello_mock = server.mock(|when, then| { + when.method(GET) + .path("/translate") + .query_param("word", "hello"); + then.status(200) + .delay(Duration::from_millis(5)) + .header("content-type", "text/html") + .body("ohi"); + }); + + let client = DeduplicatorClient::default(); + let (res1, res2, res3) = tokio::join!( + client.fetch(server.url("/translate?word=hello")), + client.fetch(server.url("/translate?word=hello")), + client.fetch(server.url("/translate?word=hello")) + ); + + hello_mock.assert_hits(1); + assert_eq!( + res1, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + assert_eq!( + res2, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + assert_eq!( + res3, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + } + + #[tokio::test] + async fn complex() { + let server = MockServer::start(); + + let hello_mock = server.mock(|when, then| { + when.method(GET) + .path("/translate") + .query_param("word", "hello"); + then.status(200) + .delay(Duration::from_millis(5)) + .header("content-type", "text/html") + .body("ohi"); + }); + let foo_mock = server.mock(|when, then| { + when.method(GET).path("/foo").query_param("word", "hello"); + then.status(200) + .delay(Duration::from_millis(5)) + .header("content-type", "text/html") + .body("bar"); + }); + + let client = DeduplicatorClient::default(); + let (res1, res2, res3, res4, res5) = tokio::join!( + client.fetch(server.url("/translate?word=hello")), + client.fetch(server.url("/translate?word=hello")), + client.fetch(server.url("/translate?word=hello")), + client.fetch(server.url("/foo?word=hello")), + client.fetch(server.url("/foo?word=hello")) + ); + + hello_mock.assert_hits(1); + foo_mock.assert_hits(1); + + assert_eq!( + res1, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + assert_eq!( + res2, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + assert_eq!( + res3, + Ok(Response { + status: 200, + body: String::from("ohi") + }) + ); + + assert_eq!( + res4, + Ok(Response { + status: 200, + body: String::from("bar") + }) + ); + assert_eq!( + res5, + Ok(Response { + status: 200, + body: String::from("bar") + }) + ); + } +}