From 8526a45e13b21cf50d144ee7d48c87ed983cfbc2 Mon Sep 17 00:00:00 2001 From: eraden Date: Sun, 13 Aug 2023 15:31:05 +0200 Subject: [PATCH] Use new authenticator --- Cargo.lock | 1 - crates/actix-jwt-session/src/lib.rs | 35 +++++++- crates/actix-jwt-session/src/redis_adapter.rs | 41 +++++++-- crates/oswilno-parking-space/src/lib.rs | 10 ++- crates/oswilno-server/Cargo.toml | 1 - crates/oswilno-server/src/main.rs | 3 +- crates/oswilno-session/Cargo.toml | 2 +- crates/oswilno-session/src/lib.rs | 84 +++++-------------- 8 files changed, 94 insertions(+), 83 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b2b6798..8687ff7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2517,7 +2517,6 @@ name = "oswilno" version = "0.1.0" dependencies = [ "actix", - "actix-jwt-session", "actix-rt", "actix-web", "actix-web-grants", diff --git a/crates/actix-jwt-session/src/lib.rs b/crates/actix-jwt-session/src/lib.rs index 745f6fd..64f34bd 100644 --- a/crates/actix-jwt-session/src/lib.rs +++ b/crates/actix-jwt-session/src/lib.rs @@ -1,5 +1,5 @@ +use actix_web::{HttpMessage, FromRequest}; use actix_web::{dev::ServiceRequest, HttpResponse}; -use actix_web::HttpMessage; use jsonwebtoken::{decode, DecodingKey, Validation}; use serde::{de::DeserializeOwned, Serialize}; use std::sync::Arc; @@ -20,6 +20,10 @@ pub enum Error { InvalidSession, #[error("No http authentication header")] NoAuthHeader, + #[error("Failed to serialize claims")] + SerializeFailed, + #[error("Unable to write claims to storage")] + WriteFailed, } impl actix_web::ResponseError for Error { @@ -35,13 +39,41 @@ impl actix_web::ResponseError for Error { } } +#[derive(Clone)] pub struct Authenticated(Arc); +impl std::ops::Deref for Authenticated { + type Target = T; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl FromRequest for Authenticated { + type Error = actix_web::error::Error; + type Future = std::future::Ready>; + + fn from_request( + req: &actix_web::HttpRequest, + _payload: &mut actix_web::dev::Payload, + ) -> Self::Future { + let value = req.extensions_mut().get::>().map(Clone::clone); + std::future::ready(value.ok_or_else(|| Error::NotFound.into())) + } +} + #[async_trait::async_trait(?Send)] pub trait TokenStorage { type ClaimsType: Claims; async fn get_from_jti(self: Arc, jti: uuid::Uuid) -> Result; + + async fn set_by_jti( + self: Arc, + claims: Self::ClaimsType, + exp: std::time::Duration, + ) -> Result<(), Error>; } struct Extractor; @@ -53,7 +85,6 @@ impl Extractor { jwt_validator: Arc, storage: Arc>, ) -> Result<(), Error> { - let authorisation_header = req .headers() .get("Authorization") diff --git a/crates/actix-jwt-session/src/redis_adapter.rs b/crates/actix-jwt-session/src/redis_adapter.rs index 335c572..b0762c0 100644 --- a/crates/actix-jwt-session/src/redis_adapter.rs +++ b/crates/actix-jwt-session/src/redis_adapter.rs @@ -3,10 +3,10 @@ use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Tr use futures_util::future::LocalBoxFuture; use jsonwebtoken::{DecodingKey, Validation}; use redis::AsyncCommands; +use std::future::{ready, Ready}; use std::marker::PhantomData; use std::rc::Rc; use std::sync::Arc; -use std::future::{ready, Ready}; #[derive(Clone)] pub struct RedisStorage { @@ -39,6 +39,20 @@ where .map_err(|_| Error::NotFound)?; bincode::deserialize(&val).map_err(|_| Error::RecordMalformed) } + + async fn set_by_jti( + self: Arc, + claims: Self::ClaimsType, + exp: std::time::Duration, + ) -> Result<(), Error> { + let pool = self.pool.clone(); + let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?; + let val = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?; + conn.set_ex::<_, _, String>(claims.jti().as_bytes(), val, exp.as_secs() as usize) + .await + .map_err(|_| Error::WriteFailed)?; + Ok(()) + } } pub struct RedisMiddleware @@ -80,18 +94,33 @@ where } } -#[derive(Debug,Clone)] +#[derive(Clone)] pub struct RedisMiddlewareFactory { jwt_decoding_key: Arc, jwt_validator: Arc, - storage: Arc, + storage: Arc>, _claims_type_marker: PhantomData, } +impl RedisMiddlewareFactory { + pub fn new( + jwt_decoding_key: Arc, + jwt_validator: Arc, + storage: Arc>, + ) -> Self { + Self { + jwt_decoding_key, + jwt_validator, + storage, + _claims_type_marker: Default::default(), + } + } +} + impl Transform for RedisMiddlewareFactory where S: Service, Error = actix_web::Error> + 'static, - ClaimsType: DeserializeOwned + 'static, + ClaimsType: Claims, { type Response = ServiceResponse; type Error = actix_web::Error; @@ -110,7 +139,6 @@ where } } - #[cfg(test)] mod tests { use super::*; @@ -127,6 +155,5 @@ mod tests { } #[tokio::test] - async fn extract() { - } + async fn extract() {} } diff --git a/crates/oswilno-parking-space/src/lib.rs b/crates/oswilno-parking-space/src/lib.rs index 09a2e71..da9fb7f 100644 --- a/crates/oswilno-parking-space/src/lib.rs +++ b/crates/oswilno-parking-space/src/lib.rs @@ -9,8 +9,8 @@ use sea_orm::ActiveValue::{NotSet, Set}; use std::collections::HashMap; use std::sync::Arc; +use oswilno_session::{Claims, Authenticated}; use oswilno_view::Layout; -use oswilno_session::UserSession; pub fn mount(config: &mut ServiceConfig) { config.service( @@ -30,7 +30,9 @@ struct AllPartialParkingSpace { } #[get("/all")] -async fn all_parking_spaces(db: Data) -> Layout { +async fn all_parking_spaces( + db: Data, +) -> Layout { let db = db.into_inner(); let main = load_parking_spaces(db).await; @@ -105,7 +107,7 @@ struct CreateParkingSpace { async fn create( db: Data, p: Form, - session: UserSession, + session: Authenticated, ) -> HttpResponse { use oswilno_contract::parking_spaces::*; let CreateParkingSpace { location, spot } = p.into_inner(); @@ -115,7 +117,7 @@ async fn create( id: NotSet, location: Set(location.clone()), spot: Set(spot.map(|n| n as i32)), - account_id: Set(session.claims.account_id()), + account_id: Set(session.account_id()), ..Default::default() }; if let Err(_e) = model.save(&*db).await { diff --git a/crates/oswilno-server/Cargo.toml b/crates/oswilno-server/Cargo.toml index 52a5fe6..0072ee6 100644 --- a/crates/oswilno-server/Cargo.toml +++ b/crates/oswilno-server/Cargo.toml @@ -15,7 +15,6 @@ oswilno-config = { path = "../oswilno-config" } oswilno-parking-space = { path = "../oswilno-parking-space" } oswilno-session = { path = "../oswilno-session" } oswilno-view = { path = "../oswilno-view" } -actix-jwt-session = { path = "../actix-jwt-session", features = ["use-redis"] } redis = { version = "0.17" } redis-async-pool = "0.2.4" sea-orm = { version = "0.11", features = ["postgres-array", "runtime-actix-rustls", "sqlx-postgres"] } diff --git a/crates/oswilno-server/src/main.rs b/crates/oswilno-server/src/main.rs index a34f76d..2a6a394 100644 --- a/crates/oswilno-server/src/main.rs +++ b/crates/oswilno-server/src/main.rs @@ -50,8 +50,7 @@ async fn main() -> std::io::Result<()> { let session_config = session_config.clone(); App::new() .wrap(middleware::Logger::default()) - .wrap(actix_jwt_session::RedisMiddleware::new()) - // .wrap(session_config.factory()) + .wrap(session_config.factory()) .app_data(Data::new(conn.clone())) .app_data(Data::new(redis.clone())) .app_data(Data::new(l10n.clone())) diff --git a/crates/oswilno-session/Cargo.toml b/crates/oswilno-session/Cargo.toml index 4a0010f..87ce57a 100644 --- a/crates/oswilno-session/Cargo.toml +++ b/crates/oswilno-session/Cargo.toml @@ -17,7 +17,7 @@ garde = { version = "0.14.0", features = ["derive"] } jsonwebtoken = "8.3.0" oswilno-contract = { path = "../oswilno-contract" } oswilno-view = { path = "../oswilno-view" } -actix-jwt-session = { path = "../actix-jwt-session" } +actix-jwt-session = { path = "../actix-jwt-session", features = ["use-redis"] } rand = "0.8.5" redis = { version = "0.17" } redis-async-pool = "0.2.4" diff --git a/crates/oswilno-session/src/lib.rs b/crates/oswilno-session/src/lib.rs index 0a524ab..7071496 100644 --- a/crates/oswilno-session/src/lib.rs +++ b/crates/oswilno-session/src/lib.rs @@ -1,14 +1,11 @@ use std::ops::Add; use std::sync::Arc; -use actix_jwt_authc::*; +pub use actix_jwt_session::{Error, RedisMiddlewareFactory, Authenticated}; use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::{get, post, HttpResponse}; use askama_actix::Template; use autometrics::autometrics; -use futures::channel::{mpsc, mpsc::Sender}; -use futures::stream::Stream; -use futures::SinkExt; use garde::Validate; use jsonwebtoken::*; use oswilno_view::{Errors, Lang, Layout, TranslationStorage}; @@ -18,14 +15,13 @@ use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; use time::ext::*; use time::OffsetDateTime; -use tokio::sync::Mutex; mod extract_session; mod hashing; pub use oswilno_view::filters; -pub type UserSession = Authenticated; +pub type UserSession = Claims; #[derive(Clone, Copy)] pub struct JWTTtl(time::Duration); @@ -51,6 +47,12 @@ pub struct Claims { jwt_id: uuid::Uuid, } +impl actix_jwt_session::Claims for Claims { + fn jti(&self) -> uuid::Uuid { + self.jwt_id + } +} + impl Claims { pub fn account_id(&self) -> i32 { self.subject @@ -76,15 +78,13 @@ const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, - invalidated_jwts_store: Data, encoding_key: Data, - factory: AuthenticateMiddlewareFactory, + factory: RedisMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config - .app_data(self.invalidated_jwts_store) .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) @@ -97,7 +97,7 @@ impl SessionConfigurator { .service(register_partial_view); } - pub fn factory(&self) -> AuthenticateMiddlewareFactory { + pub fn factory(&self) -> RedisMiddlewareFactory { self.factory.clone() } @@ -133,18 +133,13 @@ impl SessionConfigurator { let jwt_ttl = JWTTtl(31.days()); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let validator = Validation::new(JWT_SIGNING_ALGO); - let auth_middleware_settings = AuthenticateMiddlewareSettings { - jwt_decoding_key: jwt_signing_keys.decoding_key, - jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), - jwt_validator: validator, - jwt_session_key: None, - }; - let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(redis); - let auth_middleware_factory = - AuthenticateMiddlewareFactory::::new(stream, auth_middleware_settings.clone()); + let auth_middleware_factory = RedisMiddlewareFactory::::new( + Arc::new(jwt_signing_keys.decoding_key), + Arc::new(validator), + Arc::new(actix_jwt_session::RedisStorage::new(redis)), + ); Self { - invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, @@ -308,8 +303,8 @@ async fn login_inner( #[autometrics] #[get("/session")] -async fn session_info(authenticated: UserSession) -> Result { - Ok(HttpResponse::Ok().json(authenticated)) +async fn session_info(authenticated: Authenticated) -> Result { + Ok(HttpResponse::Ok().json(&*authenticated)) } #[autometrics] @@ -320,7 +315,7 @@ async fn logout( ) -> Result { { use redis::AsyncCommands; - let jwt_id = authenticated.claims.jwt_id; + let jwt_id = authenticated.jwt_id; if let Ok(mut conn) = redis.get().await { if conn.del::<_, String>(jwt_id.as_bytes()).await.is_err() {} } @@ -506,48 +501,7 @@ async fn register_internal( .json(EmptyResponse {})) } -#[derive(Clone)] -struct InvalidatedJWTStore { - // store: Arc>, - redis: redis_async_pool::RedisPool, - tx: Arc>>, -} - -impl InvalidatedJWTStore { - /// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s - fn new_with_stream( - redis: redis_async_pool::RedisPool, - ) -> ( - InvalidatedJWTStore, - impl Stream, - ) { - // let invalidated = Arc::new(DashSet::new()); - let (tx, rx) = mpsc::channel(100); - let tx_to_hold = Arc::new(Mutex::new(tx)); - ( - InvalidatedJWTStore { - // store: invalidated, - redis, - tx: tx_to_hold, - }, - rx, - ) - } - - async fn add_to_invalidated(&self, authenticated: Authenticated) { - // self.store.insert(authenticated.jwt.clone()); - let mut tx = self.tx.lock().await; - if let Err(_e) = tx - .send(InvalidatedTokensEvent::Add(authenticated.jwt)) - .await - { - #[cfg(feature = "tracing")] - error!(error = ?_e, "Failed to send update on adding to invalidated") - } - } -} - -struct JwtSigningKeys { +pub struct JwtSigningKeys { encoding_key: EncodingKey, decoding_key: DecodingKey, }