Valid solution

This commit is contained in:
eraden 2023-07-21 22:35:18 +02:00
commit 9bfc723bcb
3 changed files with 289 additions and 0 deletions

9
.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
/target
# Added by cargo
#
# already existing elements were commented out
#/target
/Cargo.lock

13
Cargo.toml Normal file
View File

@ -0,0 +1,13 @@
[package]
name = "deduplicator"
version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = "0.1.72"
tokio = { version = "1.29.1", features = ["full"] }
reqwest = "0.11.18"
futures-channel = "0.3.28"
[dev-dependencies]
httpmock = "0.6.8"

267
src/lib.rs Normal file
View File

@ -0,0 +1,267 @@
#![allow(dead_code)]
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use async_trait::async_trait;
use futures_channel::oneshot::{channel, Sender};
#[async_trait]
trait Deduplicator: Default + Send + Sync + 'static {
async fn fetch(&self, url: String) -> Result<Response, Error>;
}
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
struct Response {
status: u16,
body: String,
}
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
struct Error {
status: Option<u16>,
message: String,
}
type ReqRes = Result<Response, Error>;
type ReqResQueue = Vec<Sender<ReqRes>>;
pub struct DeduplicatorClient {
pending: Arc<RwLock<HashMap<String, ReqResQueue>>>,
client: reqwest::Client,
}
impl Default for DeduplicatorClient {
fn default() -> Self {
Self {
pending: Arc::new(RwLock::new(HashMap::with_capacity(64))),
client: reqwest::ClientBuilder::default().build().unwrap(),
}
}
}
#[async_trait]
impl Deduplicator for DeduplicatorClient {
async fn fetch(&self, url: String) -> ReqRes {
let is_pending = {
let lock = self.pending.read().unwrap();
lock.contains_key(&url)
};
if is_pending {
let rx = {
let mut lock = self.pending.write().unwrap();
let (tx, rx) = channel();
lock.get_mut(&url).unwrap().push(tx);
rx
};
rx.await.unwrap()
} else {
start_new_req(url, self.client.clone(), self.pending.clone()).await
}
}
}
async fn start_new_req(
url: String,
client: reqwest::Client,
pending: Arc<RwLock<HashMap<String, ReqResQueue>>>,
) -> Result<Response, Error> {
let mut queue = Vec::with_capacity(64);
let (tx, rx) = channel();
queue.push(tx);
{
let mut lock = pending.write().unwrap();
lock.insert(url.clone(), queue);
}
tokio::spawn(perform_request(url, client, pending));
rx.await.unwrap()
}
async fn perform_request(
url: String,
client: reqwest::Client,
pending: Arc<RwLock<HashMap<String, ReqResQueue>>>,
) {
let res = client.get(&url).send().await;
let res = match res {
Ok(res) => {
if res.status().is_success() {
Ok(Response {
status: res.status().as_u16(),
body: res.text().await.unwrap_or_default(),
})
} else {
Err(Error {
status: Some(res.status().as_u16()),
message: res.text().await.unwrap_or_default(),
})
}
}
Err(e) => {
Err(Error {
status: e.status().map(|s| s.as_u16()),
message: e.to_string(),
})
}
};
let queue = pending.write().unwrap().remove(&url).unwrap();
for tx in queue {
tx.send(res.clone()).unwrap();
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use httpmock::prelude::*;
#[tokio::test]
async fn single() {
let server = MockServer::start();
let hello_mock = server.mock(|when, then| {
when.method(GET)
.path("/translate")
.query_param("word", "hello");
then.status(200)
.delay(Duration::from_millis(50))
.header("content-type", "text/html")
.body("ohi");
});
let client = DeduplicatorClient::default();
let res = client.fetch(server.url("/translate?word=hello")).await;
hello_mock.assert_hits(1);
assert_eq!(
res,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
}
#[tokio::test]
async fn three() {
let server = MockServer::start();
let hello_mock = server.mock(|when, then| {
when.method(GET)
.path("/translate")
.query_param("word", "hello");
then.status(200)
.delay(Duration::from_millis(5))
.header("content-type", "text/html")
.body("ohi");
});
let client = DeduplicatorClient::default();
let (res1, res2, res3) = tokio::join!(
client.fetch(server.url("/translate?word=hello")),
client.fetch(server.url("/translate?word=hello")),
client.fetch(server.url("/translate?word=hello"))
);
hello_mock.assert_hits(1);
assert_eq!(
res1,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
assert_eq!(
res2,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
assert_eq!(
res3,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
}
#[tokio::test]
async fn complex() {
let server = MockServer::start();
let hello_mock = server.mock(|when, then| {
when.method(GET)
.path("/translate")
.query_param("word", "hello");
then.status(200)
.delay(Duration::from_millis(5))
.header("content-type", "text/html")
.body("ohi");
});
let foo_mock = server.mock(|when, then| {
when.method(GET).path("/foo").query_param("word", "hello");
then.status(200)
.delay(Duration::from_millis(5))
.header("content-type", "text/html")
.body("bar");
});
let client = DeduplicatorClient::default();
let (res1, res2, res3, res4, res5) = tokio::join!(
client.fetch(server.url("/translate?word=hello")),
client.fetch(server.url("/translate?word=hello")),
client.fetch(server.url("/translate?word=hello")),
client.fetch(server.url("/foo?word=hello")),
client.fetch(server.url("/foo?word=hello"))
);
hello_mock.assert_hits(1);
foo_mock.assert_hits(1);
assert_eq!(
res1,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
assert_eq!(
res2,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
assert_eq!(
res3,
Ok(Response {
status: 200,
body: String::from("ohi")
})
);
assert_eq!(
res4,
Ok(Response {
status: 200,
body: String::from("bar")
})
);
assert_eq!(
res5,
Ok(Response {
status: 200,
body: String::from("bar")
})
);
}
}