User exractor

This commit is contained in:
eraden 2024-02-11 07:08:46 +01:00
parent 694d89bcd5
commit 226cc764cb
8 changed files with 310 additions and 55 deletions

View File

@ -1,11 +1,11 @@
[workspace] [workspace]
members = [ members = [
"crates/squadron-space",
"crates/squadron-api", "crates/squadron-api",
"crates/squadron-bg",
"crates/squadron-beat", "crates/squadron-beat",
"crates/squadron-bg",
"crates/squadron-contract", "crates/squadron-contract",
"crates/squadron-plug", "crates/squadron-plug",
"crates/squadron-space",
] ]
resolver = "2" resolver = "2"

View File

@ -5,51 +5,50 @@ edition = "2021"
[dependencies] [dependencies]
actix = "0.13.1" actix = "0.13.1"
actix-jwt-session = "1.0.2"
actix-web = { version = "4.4.1", default-features = false, features = ["rustls", "actix-macros", "macros", "experimental-io-uring"] } actix-web = { version = "4.4.1", default-features = false, features = ["rustls", "actix-macros", "macros", "experimental-io-uring"] }
async-trait = "0.1.77" async-trait = "0.1.77"
base64 = "0.21.7"
basen = "0.1.0"
bincode = "1.3.3" bincode = "1.3.3"
chrono = { version = "0.4.32", default-features = false, features = ["clock", "serde"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "deref", "deref_mut", "from"] }
dotenv = "0.15.0"
entities = { workspace = true }
figment = { version = "0.10.14", features = ["env", "toml"] }
futures = "0.3.30" futures = "0.3.30"
futures-util = "0.3.30" futures-util = "0.3.30"
hmac = { version = "0.12.1", features = ["std"] }
http-api-isahc-client = { version = "0.2.2", features = ["with-sleep-via-tokio"] }
humantime = "2.1.0"
oauth2 = "4.4.2"
oauth2-amazon = "0.2.0"
oauth2-client = "0.2.0"
oauth2-core = "0.2.0"
oauth2-github = "0.2.0"
oauth2-gitlab = "0.2.0"
oauth2-google = "0.2.0"
oauth2-signin = "0.2.0"
password-hash = "0.5.0"
rand = { version = "0.8.5", features = ["serde"] }
reqwest = { version = "0.11.23", default-features = false, features = ["rustls", "tokio-rustls", "tokio-socks", "multipart"] }
rumqttc = { version = "0.23.0", features = ["use-rustls"] } rumqttc = { version = "0.23.0", features = ["use-rustls"] }
rust-s3 = { version = "0.33.0", features = ["tokio-rustls-tls", "futures-util", "futures-io"] } rust-s3 = { version = "0.33.0", features = ["tokio-rustls-tls", "futures-util", "futures-io"] }
sea-orm = { version = "0.12.11", features = ["postgres-array", "sqlx-all"] } sea-orm = { version = "0.12.11", features = ["postgres-array", "sqlx-all"] }
serde = "1.0.195"
serde_json = { version = "1.0.111", features = ["raw_value", "alloc"] }
tokio = { version = "1.35.1", features = ["full"] }
squadron-contract = { workspace = true }
uuid = { version = "1.7.0", features = ["v4", "serde"] }
entities = { workspace = true }
figment = { version = "0.10.14", features = ["env", "toml"] }
serde-aux = "4.4.0"
actix-jwt-session = "1.0.2"
sqlx = { version = "0.7.3", features = ["runtime-tokio"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "serde", "serde_json", "chrono", "json"] }
serde-email = { version = "3.0.1", features = ["all"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "deref", "deref_mut", "from"] }
thiserror = "1.0.56"
rand = { version = "0.8.5", features = ["serde"] }
oauth2 = "4.4.2"
oauth2-google = "0.2.0"
oauth2-github = "0.2.0"
oauth2-gitlab = "0.2.0"
oauth2-amazon = "0.2.0"
oauth2-client = "0.2.0"
oauth2-signin = "0.2.0"
oauth2-core = "0.2.0"
reqwest = { version = "0.11.23", default-features = false, features = ["rustls", "tokio-rustls", "tokio-socks", "multipart"] }
http-api-isahc-client = { version = "0.2.2", features = ["with-sleep-via-tokio"] }
dotenv = "0.15.0"
chrono = { version = "0.4.32", default-features = false, features = ["clock", "serde"] }
validators = { version = "0.25.3", default-features = false, features = ["email", "derive", "all-validators"] }
sentry = { version = "0.32.1", default-features = false, features = ["tokio", "rustls", "tracing", "isahc", "sentry-backtrace", "sentry-log", "sentry-contexts", "backtrace", "panic"] } sentry = { version = "0.32.1", default-features = false, features = ["tokio", "rustls", "tracing", "isahc", "sentry-backtrace", "sentry-log", "sentry-contexts", "backtrace", "panic"] }
basen = "0.1.0" serde = "1.0.195"
base64 = "0.21.7" serde-aux = "4.4.0"
hmac = { version = "0.12.1", features = ["std"] } serde-email = { version = "3.0.1", features = ["all"] }
serde_json = { version = "1.0.111", features = ["raw_value", "alloc"] }
sha2 = "0.10.8" sha2 = "0.10.8"
humantime = "2.1.0" sqlx = { version = "0.7.3", features = ["runtime-tokio"] }
password-hash = "0.5.0" squadron-contract = { workspace = true }
tracing-test = { version = "0.2.4", features = ["no-env-filter"] } thiserror = "1.0.56"
tokio = { version = "1.35.1", features = ["full"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["serde", "serde_json", "chrono", "json"] }
uuid = { version = "1.7.0", features = ["v4", "serde"] }
validators = { version = "0.25.3", default-features = false, features = ["email", "derive", "all-validators"] }
[dev-dependencies] [dev-dependencies]
tracing-test = "0.2.4" tracing-test = { version = "0.2.4", features = ["no-env-filter"] }

View File

@ -1,26 +1,281 @@
use actix_jwt_session::Authenticated;
use actix_web::web::Data; use actix_web::web::Data;
use actix_web::FromRequest; use actix_web::{FromRequest, HttpMessage};
use derive_more::Deref; use derive_more::Deref;
use entities::prelude::Users; use entities::prelude::Users;
use entities::users::*; use entities::users::*;
use futures_util::future::LocalBoxFuture; use futures_util::future::LocalBoxFuture;
use futures_util::FutureExt;
use sea_orm::EntityTrait;
use crate::utils::uidb::Unauthorized; use crate::session::AppClaims;
use crate::DatabaseConnection; use crate::DatabaseConnection;
#[derive(Debug, Deref)] #[derive(Debug, Deref, serde::Serialize)]
#[serde(transparent)]
pub struct RequireUser(pub Model); pub struct RequireUser(pub Model);
impl FromRequest for RequireUser { impl FromRequest for RequireUser {
type Error = Unauthorized; type Error = crate::models::Error;
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
fn from_request( fn from_request(
req: &actix_web::HttpRequest, req: &actix_web::HttpRequest,
payload: &mut actix_web::dev::Payload, _payload: &mut actix_web::dev::Payload,
) -> Self::Future { ) -> Self::Future {
let db = req.app_data::<Data<DatabaseConnection>>().cloned(); tracing::debug!("Start user from req");
let db = req
.app_data::<Data<DatabaseConnection>>()
.cloned()
.ok_or(crate::models::Error::DatabaseError);
tracing::debug!("DB exists");
let r = req
.extensions()
.get::<Authenticated<AppClaims>>()
.ok_or(crate::models::Error::UserRequired)
.map(|s| s.account_id());
tracing::debug!("Authenticated exists");
todo!() async move {
let db = db?;
let id = r?;
tracing::debug!("Looking for user {id:?}");
Users::find_by_id(id)
.one(&**db)
.await
.map_err(|e| {
tracing::error!("Failed to connect to db: {e}");
crate::models::Error::DatabaseError
})?
.ok_or_else(|| {
tracing::debug!("User {id:?} does not exists");
crate::models::Error::UserRequired
})
.map(Self)
}
.boxed_local()
}
}
#[cfg(test)]
mod tests {
use actix_jwt_session::{
Hashing, JwtTtl, RefreshTtl, SessionMiddlewareFactory, JWT_COOKIE_NAME, JWT_HEADER_NAME,
REFRESH_COOKIE_NAME, REFRESH_HEADER_NAME,
};
use actix_web::body::to_bytes;
use actix_web::web::Data;
use actix_web::{test, App};
use reqwest::{Method, StatusCode};
use sea_orm::Database;
use squadron_contract::deadpool_redis;
use tracing_test::traced_test;
use uuid::Uuid;
use super::*;
use crate::session;
macro_rules! create_app {
($app: ident, $session_storage: ident, $db: ident) => {
std::env::set_var(
"DATABASE_URL",
"postgres://postgres@0.0.0.0:5432/squadron_test",
);
let redis = deadpool_redis::Config::from_url("redis://0.0.0.0:6379")
.create_pool(Some(deadpool_redis::Runtime::Tokio1))
.expect("Can't connect to redis");
let $db: sea_orm::prelude::DatabaseConnection =
Database::connect("postgres://postgres@0.0.0.0:5432/squadron_test")
.await
.expect("Failed to connect to database");
let ($session_storage, factory) =
SessionMiddlewareFactory::<session::AppClaims>::build_ed_dsa()
.with_redis_pool(redis.clone())
// Check if header "Authorization" exists and contains Bearer with encoded JWT
.with_jwt_header(JWT_HEADER_NAME)
// Check if cookie JWT exists and contains encoded JWT
.with_jwt_cookie(JWT_COOKIE_NAME)
.with_refresh_header(REFRESH_HEADER_NAME)
// Check if cookie JWT exists and contains encoded JWT
.with_refresh_cookie(REFRESH_COOKIE_NAME)
.with_jwt_json(&["access_token"])
.finish();
let $db = Data::new($db.clone());
let $app = test::init_service(
App::new()
.app_data(Data::new($session_storage.clone()))
.app_data($db.clone())
.app_data(Data::new(redis))
.wrap(actix_web::middleware::NormalizePath::trim())
.wrap(actix_web::middleware::Logger::default())
.wrap(factory)
.service(test_path),
)
.await;
};
}
async fn create_user(
db: Data<DatabaseConnection>,
user_name: &str,
pass: &str,
) -> entities::users::Model {
use entities::users::*;
use sea_orm::*;
if let Ok(Some(user)) = Users::find()
.filter(Column::Email.eq(format!("{user_name}@example.com")))
.one(&**db)
.await
{
return user;
}
let pass = Hashing::encrypt(pass).unwrap();
Users::insert(ActiveModel {
password: Set(pass),
email: Set(Some(format!("{user_name}@example.com"))),
display_name: Set(user_name.to_string()),
username: Set(Uuid::new_v4().to_string()),
first_name: Set("".to_string()),
last_name: Set("".to_string()),
last_location: Set("".to_string()),
created_location: Set("".to_string()),
is_password_autoset: Set(false),
token: Set(Uuid::new_v4().to_string()),
billing_address_country: Set("".to_string()),
user_timezone: Set("UTC".to_string()),
last_login_ip: Set("0.0.0.0".to_string()),
last_login_medium: Set("None".to_string()),
last_logout_ip: Set("0.0.0.0".to_string()),
last_login_uagent: Set("test".to_string()),
is_active: Set(true),
avatar: Set("".to_string()),
..Default::default()
})
.exec_with_returning(&**db)
.await
.unwrap()
}
#[actix_web::get("/test")]
async fn test_path(user: RequireUser) -> actix_web::HttpResponse {
actix_web::HttpResponse::Ok().json(serde_json::json!({ "user": user }))
}
#[traced_test]
#[actix_web::test]
async fn valid() {
create_app!(app, session, db);
let user = create_user(db, "valid_extract_user", "QWEqwwe123@#$").await;
let pair = session
.store(
AppClaims {
expiration_time: (chrono::Utc::now() + chrono::Duration::days(100))
.timestamp_millis(),
issued_at: 0,
subject: "999999999".into(),
audience: session::Audience::Web,
jwt_id: Uuid::new_v4(),
account_id: user.id,
not_before: 0,
},
JwtTtl::new(actix_jwt_session::Duration::days(9999)),
RefreshTtl::new(actix_jwt_session::Duration::days(9999)),
)
.await
.unwrap();
let req = test::TestRequest::default()
// .insert_header(ContentType::json())
.insert_header((JWT_HEADER_NAME, pair.jwt.encode().unwrap()))
.uri("/test")
.method(Method::GET)
.to_request();
let resp = test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body = resp.into_body();
let json: serde_json::Value =
serde_json::from_slice(&to_bytes(body).await.unwrap()[..]).unwrap();
assert_eq!(json.get("user").is_some(), true);
}
#[traced_test]
#[actix_web::test]
async fn bad_account_id() {
create_app!(app, session, db);
let _user = create_user(db, "valid_extract_user", "QWEqwwe123@#$").await;
let pair = session
.store(
AppClaims {
expiration_time: (chrono::Utc::now() + chrono::Duration::days(100))
.timestamp_millis(),
issued_at: 0,
subject: "999999999".into(),
audience: session::Audience::Web,
jwt_id: Uuid::new_v4(),
account_id: Uuid::new_v4(),
not_before: 0,
},
JwtTtl::new(actix_jwt_session::Duration::days(9999)),
RefreshTtl::new(actix_jwt_session::Duration::days(9999)),
)
.await
.unwrap();
let req = test::TestRequest::default()
.insert_header((JWT_HEADER_NAME, pair.jwt.encode().unwrap()))
.uri("/test")
.method(Method::GET)
.to_request();
let resp = test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
let body = resp.into_body();
let json: serde_json::Value =
serde_json::from_slice(&to_bytes(body).await.unwrap()[..]).unwrap();
assert_eq!(json.get("error").is_some(), true);
}
#[traced_test]
#[actix_web::test]
async fn no_token() {
create_app!(app, session, db);
let user = create_user(db, "valid_extract_user", "QWEqwwe123@#$").await;
let _pair = session
.store(
AppClaims {
expiration_time: (chrono::Utc::now() + chrono::Duration::days(100))
.timestamp_millis(),
issued_at: 0,
subject: "999999999".into(),
audience: session::Audience::Web,
jwt_id: Uuid::new_v4(),
account_id: user.id,
not_before: 0,
},
JwtTtl::new(actix_jwt_session::Duration::days(9999)),
RefreshTtl::new(actix_jwt_session::Duration::days(9999)),
)
.await
.unwrap();
let req = test::TestRequest::default()
// .insert_header((JWT_HEADER_NAME, pair.jwt.encode().unwrap()))
.uri("/test")
.method(Method::GET)
.to_request();
let resp = test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
let body = resp.into_body();
let json: serde_json::Value =
serde_json::from_slice(&to_bytes(body).await.unwrap()[..]).unwrap();
assert_eq!(json.get("error").is_some(), true);
} }
} }

View File

@ -63,3 +63,6 @@ async fn single_api_token(
} }
} }
} }
#[cfg(test)]
mod single_tests {}

View File

@ -2,7 +2,6 @@
name = "squadron-beat" name = "squadron-beat"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]

View File

@ -5,12 +5,12 @@ edition = "2021"
[dependencies] [dependencies]
bincode = "1.3.3" bincode = "1.3.3"
derive_more = { version = "0.99.17", default-features = false, features = ["display", "deref", "deref_mut", "constructor"] }
serde = { version = "1.0.195", default-features = false }
uuid = { version = "1.7.0", features = ["v4", "serde"] }
rumqttc = { version = "0.23.0", features = ["use-rustls"] }
thiserror = "1.0.56"
deadpool-redis = { version = "0.14.0", features = ["serde"] }
redis = { version = "0.24", features = ["serde"] }
chrono = { version = "0.4.31", default-features = false, features = ["serde", "std", "libc", "pure-rust-locales", "clock"] } chrono = { version = "0.4.31", default-features = false, features = ["serde", "std", "libc", "pure-rust-locales", "clock"] }
deadpool-redis = { version = "0.14.0", features = ["serde"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "deref", "deref_mut", "constructor"] }
redis = { version = "0.24", features = ["serde"] }
rumqttc = { version = "0.23.0", features = ["use-rustls"] }
serde = { version = "1.0.195", default-features = false }
serde_json = "1.0.111" serde_json = "1.0.111"
thiserror = "1.0.56"
uuid = { version = "1.7.0", features = ["v4", "serde"] }

View File

@ -15,5 +15,5 @@ rumqttc = "0.23.0"
rust-s3 = { version = "0.33.0", features = ["tokio-rustls-tls", "futures-util", "futures-io"] } rust-s3 = { version = "0.33.0", features = ["tokio-rustls-tls", "futures-util", "futures-io"] }
serde = "1.0.195" serde = "1.0.195"
serde_json = "1.0.111" serde_json = "1.0.111"
tokio = { version = "1.35.1", features = ["full"] }
squadron-contract = { workspace = true } squadron-contract = { workspace = true }
tokio = { version = "1.35.1", features = ["full"] }

View File

@ -2,7 +2,6 @@
name = "squadron-space" name = "squadron-space"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]