diff --git a/crates/actix-jwt-session/src/lib.rs b/crates/actix-jwt-session/src/lib.rs index 44413f1..d198914 100644 --- a/crates/actix-jwt-session/src/lib.rs +++ b/crates/actix-jwt-session/src/lib.rs @@ -1,10 +1,11 @@ use actix_web::{dev::ServiceRequest, HttpResponse}; use actix_web::{FromRequest, HttpMessage}; +use async_trait::async_trait; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation}; use serde::{de::DeserializeOwned, Serialize}; +use std::marker::PhantomData; use std::sync::Arc; use uuid::Uuid; -use async_trait::async_trait; pub static HEADER_NAME: &str = "Authorization"; @@ -185,8 +186,9 @@ impl SessionStorage { } #[async_trait(?Send)] -pub trait Extractor { - async fn extract_jwt( +pub trait Extractor: Send + Sync + 'static { + async fn extract_jwt( + &self, req: &ServiceRequest, jwt_encoding_key: Arc, jwt_decoding_key: Arc, @@ -194,27 +196,50 @@ pub trait Extractor { storage: SessionStorage, ) -> Result<(), Error>; - fn decode(value: &str, - - jwt_decoding_key: Arc, + fn decode( + &self, + value: &str, + jwt_decoding_key: Arc, algorithm: Algorithm, + ) -> Result { + decode::(value, &*jwt_decoding_key, &Validation::new(algorithm)) + .map_err(|_e| { + // let error_message = e.to_string(); + Error::InvalidSession + }) + .map(|t| t.claims) + } - ) -> Result { - decode::(value, &*jwt_decoding_key, &Validation::new(algorithm)).map_err( - |_e| { - // let error_message = e.to_string(); - Error::InvalidSession - }, - ) + async fn validate( + &self, + claims: &ClaimsType, + storage: SessionStorage, + ) -> Result<(), Error> { + let stored = storage + .0 + .clone() + .get_from_jti(claims.jti()) + .await + .map_err(|_| Error::InvalidSession)?; + if &stored != claims { + return Err(Error::InvalidSession); + } + Ok(()) } } -pub struct CookieExtractor; +#[derive(Default)] +pub struct CookieExtractor(PhantomData); + +impl CookieExtractor { + pub fn new() -> Self { Self(Default::default()) } +} #[async_trait(?Send)] -impl Extractor for CookieExtractor { - async fn extract_jwt( +impl Extractor for CookieExtractor { + async fn extract_jwt( + &self, req: &ServiceRequest, jwt_encoding_key: Arc, jwt_decoding_key: Arc, @@ -222,16 +247,29 @@ impl Extractor for CookieExtractor { storage: SessionStorage, ) -> Result<(), Error> { let Some(cookie) = req.cookie(HEADER_NAME) else {return Ok(())}; - let value = cookie.value(); + let as_str = cookie.value(); + let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?; + self.validate(&decoded_claims, storage).await?; + req.extensions_mut().insert(Authenticated { + claims: Arc::new(decoded_claims), + jwt_encoding_key, + algorithm, + }); Ok(()) } } -pub struct HeaderExtractor; +#[derive(Default)] +pub struct HeaderExtractor(PhantomData); + +impl HeaderExtractor { + pub fn new() -> Self { Self(Default::default()) } +} #[async_trait(?Send)] -impl Extractor for HeaderExtractor { - async fn extract_jwt( +impl Extractor for HeaderExtractor { + async fn extract_jwt( + &self, req: &ServiceRequest, jwt_encoding_key: Arc, jwt_decoding_key: Arc, @@ -253,27 +291,10 @@ impl Extractor for HeaderExtractor { .or_else(|| as_str.strip_prefix("bearer ")) .unwrap_or(as_str); - let decoded_claims = - decode::(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err( - |_e| { - // let error_message = e.to_string(); - Error::InvalidSession - }, - )?; - - let stored = storage - .0 - .clone() - .get_from_jti(decoded_claims.claims.jti()) - .await - .map_err(|_| Error::InvalidSession)?; - - if stored != decoded_claims.claims { - return Err(Error::InvalidSession); - } - + let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?; + self.validate(&decoded_claims, storage).await?; req.extensions_mut().insert(Authenticated { - claims: Arc::new(decoded_claims.claims), + claims: Arc::new(decoded_claims), jwt_encoding_key, algorithm, }); diff --git a/crates/actix-jwt-session/src/redis_adapter.rs b/crates/actix-jwt-session/src/redis_adapter.rs index 14a60d0..4696498 100644 --- a/crates/actix-jwt-session/src/redis_adapter.rs +++ b/crates/actix-jwt-session/src/redis_adapter.rs @@ -89,6 +89,7 @@ where jwt_decoding_key: Arc, algorithm: Algorithm, storage: SessionStorage, + extractors: Arc>>>, } impl Service for RedisMiddleware @@ -110,16 +111,25 @@ where let jwt_encoding_key = self.jwt_encoding_key.clone(); let algorithm = self.algorithm; let storage = self.storage.clone(); + let extractors = self.extractors.clone(); async move { - HeaderExtractor::extract_jwt( - &req, - jwt_encoding_key, - jwt_decoding_key, - algorithm, - storage, - ) - .await?; + let mut last_error = None; + for extractor in extractors.iter() { + match extractor.extract_jwt( + &req, + jwt_encoding_key.clone(), + jwt_decoding_key.clone(), + algorithm, + storage.clone(), + ).await { + Ok(_) => break, + Err(e) => { last_error = Some(e); }, + }; + } + if let Some(e) = last_error { + return Err(e)?; + } let res = svc.call(req).await?; Ok(res) } @@ -133,6 +143,7 @@ pub struct RedisMiddlewareFactory { jwt_decoding_key: Arc, algorithm: Algorithm, storage: SessionStorage, + extractors: Arc>>>, _claims_type_marker: PhantomData, } @@ -142,6 +153,7 @@ impl RedisMiddlewareFactory { jwt_decoding_key: Arc, algorithm: Algorithm, pool: redis_async_pool::RedisPool, + extractors: Vec>>, ) -> Self { let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm)); @@ -150,6 +162,7 @@ impl RedisMiddlewareFactory { jwt_decoding_key, algorithm, storage: SessionStorage(storage), + extractors: Arc::new(extractors), _claims_type_marker: Default::default(), } } @@ -177,6 +190,7 @@ where jwt_encoding_key: self.jwt_encoding_key.clone(), jwt_decoding_key: self.jwt_decoding_key.clone(), algorithm: self.algorithm, + extractors: self.extractors.clone(), _claims_type_marker: PhantomData, })) } diff --git a/crates/oswilno-parking-space/src/lib.rs b/crates/oswilno-parking-space/src/lib.rs index 86fc363..de501a5 100644 --- a/crates/oswilno-parking-space/src/lib.rs +++ b/crates/oswilno-parking-space/src/lib.rs @@ -18,7 +18,7 @@ pub fn mount(config: &mut ServiceConfig) { scope("/parking-spaces") .service(form_show) .service(all_parking_spaces) - .service(create) + .service(create), ); } @@ -126,7 +126,7 @@ async fn load_parking_spaces(db: Arc) -> AllPartialParkingSp async fn form_show(req: HttpRequest, session: Authenticated) -> HttpResponse { let session = session.into(); let body = ParkingSpaceFormPartial { -..Default::default() + ..Default::default() }; let main = Main { body, diff --git a/crates/oswilno-session/src/lib.rs b/crates/oswilno-session/src/lib.rs index e4b7acc..87230b8 100644 --- a/crates/oswilno-session/src/lib.rs +++ b/crates/oswilno-session/src/lib.rs @@ -1,7 +1,7 @@ use std::ops::Add; use std::sync::Arc; -use actix_jwt_session::SessionStorage; +use actix_jwt_session::{SessionStorage, CookieExtractor, HeaderExtractor}; pub use actix_jwt_session::{Error, RedisMiddlewareFactory}; use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::{get, post, HttpRequest, HttpResponse}; @@ -133,6 +133,7 @@ impl SessionConfigurator { Arc::new(jwt_signing_keys.decoding_key), Algorithm::EdDSA, redis, + vec![Box::new(CookieExtractor::::new()), Box::new(HeaderExtractor::::new())] ); Self { @@ -295,7 +296,9 @@ async fn login_inner( } }; - let cookie = actix_web::cookie::Cookie::build(actix_jwt_session::HEADER_NAME, &bearer_token).http_only(true).finish(); + let cookie = actix_web::cookie::Cookie::build(actix_jwt_session::HEADER_NAME, &bearer_token) + .http_only(true) + .finish(); Ok(HttpResponse::Ok() .append_header(( actix_jwt_session::HEADER_NAME,