2 kinds of session extractors

This commit is contained in:
eraden 2023-08-17 07:58:15 +02:00
parent 7069e58f1d
commit f20dbfc14b
4 changed files with 90 additions and 52 deletions

View File

@ -1,10 +1,11 @@
use actix_web::{dev::ServiceRequest, HttpResponse};
use actix_web::{FromRequest, HttpMessage};
use async_trait::async_trait;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
use std::sync::Arc;
use uuid::Uuid;
use async_trait::async_trait;
pub static HEADER_NAME: &str = "Authorization";
@ -185,8 +186,9 @@ impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
}
#[async_trait(?Send)]
pub trait Extractor {
async fn extract_jwt<ClaimsType: Claims>(
pub trait Extractor<ClaimsType: Claims>: Send + Sync + 'static {
async fn extract_jwt(
&self,
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
@ -194,27 +196,50 @@ pub trait Extractor {
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error>;
fn decode<ClaimsType: Claims>(value: &str,
fn decode(
&self,
value: &str,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
) -> Result<ClaimsType, Error> {
decode::<ClaimsType>(value, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
|_e| {
decode::<ClaimsType>(value, &*jwt_decoding_key, &Validation::new(algorithm))
.map_err(|_e| {
// let error_message = e.to_string();
Error::InvalidSession
},
)
})
.map(|t| t.claims)
}
async fn validate(
&self,
claims: &ClaimsType,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let stored = storage
.0
.clone()
.get_from_jti(claims.jti())
.await
.map_err(|_| Error::InvalidSession)?;
if &stored != claims {
return Err(Error::InvalidSession);
}
Ok(())
}
}
pub struct CookieExtractor;
#[derive(Default)]
pub struct CookieExtractor<ClaimsType>(PhantomData<ClaimsType>);
impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
pub fn new() -> Self { Self(Default::default()) }
}
#[async_trait(?Send)]
impl Extractor for CookieExtractor {
async fn extract_jwt<ClaimsType: Claims>(
impl<ClaimsType: Claims> Extractor<ClaimsType> for CookieExtractor<ClaimsType> {
async fn extract_jwt(
&self,
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
@ -222,16 +247,29 @@ impl Extractor for CookieExtractor {
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let Some(cookie) = req.cookie(HEADER_NAME) else {return Ok(())};
let value = cookie.value();
let as_str = cookie.value();
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
}
}
pub struct HeaderExtractor;
#[derive(Default)]
pub struct HeaderExtractor<ClaimsType>(PhantomData<ClaimsType>);
impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
pub fn new() -> Self { Self(Default::default()) }
}
#[async_trait(?Send)]
impl Extractor for HeaderExtractor {
async fn extract_jwt<ClaimsType: Claims>(
impl<ClaimsType: Claims> Extractor<ClaimsType> for HeaderExtractor<ClaimsType> {
async fn extract_jwt(
&self,
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
@ -253,27 +291,10 @@ impl Extractor for HeaderExtractor {
.or_else(|| as_str.strip_prefix("bearer "))
.unwrap_or(as_str);
let decoded_claims =
decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
|_e| {
// let error_message = e.to_string();
Error::InvalidSession
},
)?;
let stored = storage
.0
.clone()
.get_from_jti(decoded_claims.claims.jti())
.await
.map_err(|_| Error::InvalidSession)?;
if stored != decoded_claims.claims {
return Err(Error::InvalidSession);
}
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims.claims),
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});

View File

@ -89,6 +89,7 @@ where
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
extractors: Arc<Vec<Box<dyn Extractor<ClaimsType>>>>,
}
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
@ -110,16 +111,25 @@ where
let jwt_encoding_key = self.jwt_encoding_key.clone();
let algorithm = self.algorithm;
let storage = self.storage.clone();
let extractors = self.extractors.clone();
async move {
HeaderExtractor::extract_jwt(
let mut last_error = None;
for extractor in extractors.iter() {
match extractor.extract_jwt(
&req,
jwt_encoding_key,
jwt_decoding_key,
jwt_encoding_key.clone(),
jwt_decoding_key.clone(),
algorithm,
storage,
)
.await?;
storage.clone(),
).await {
Ok(_) => break,
Err(e) => { last_error = Some(e); },
};
}
if let Some(e) = last_error {
return Err(e)?;
}
let res = svc.call(req).await?;
Ok(res)
}
@ -133,6 +143,7 @@ pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
extractors: Arc<Vec<Box<dyn Extractor<ClaimsType>>>>,
_claims_type_marker: PhantomData<ClaimsType>,
}
@ -142,6 +153,7 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
pool: redis_async_pool::RedisPool,
extractors: Vec<Box<dyn Extractor<ClaimsType>>>,
) -> Self {
let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm));
@ -150,6 +162,7 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
jwt_decoding_key,
algorithm,
storage: SessionStorage(storage),
extractors: Arc::new(extractors),
_claims_type_marker: Default::default(),
}
}
@ -177,6 +190,7 @@ where
jwt_encoding_key: self.jwt_encoding_key.clone(),
jwt_decoding_key: self.jwt_decoding_key.clone(),
algorithm: self.algorithm,
extractors: self.extractors.clone(),
_claims_type_marker: PhantomData,
}))
}

View File

@ -18,7 +18,7 @@ pub fn mount(config: &mut ServiceConfig) {
scope("/parking-spaces")
.service(form_show)
.service(all_parking_spaces)
.service(create)
.service(create),
);
}

View File

@ -1,7 +1,7 @@
use std::ops::Add;
use std::sync::Arc;
use actix_jwt_session::SessionStorage;
use actix_jwt_session::{SessionStorage, CookieExtractor, HeaderExtractor};
pub use actix_jwt_session::{Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpRequest, HttpResponse};
@ -133,6 +133,7 @@ impl SessionConfigurator {
Arc::new(jwt_signing_keys.decoding_key),
Algorithm::EdDSA,
redis,
vec![Box::new(CookieExtractor::<Claims>::new()), Box::new(HeaderExtractor::<Claims>::new())]
);
Self {
@ -295,7 +296,9 @@ async fn login_inner(
}
};
let cookie = actix_web::cookie::Cookie::build(actix_jwt_session::HEADER_NAME, &bearer_token).http_only(true).finish();
let cookie = actix_web::cookie::Cookie::build(actix_jwt_session::HEADER_NAME, &bearer_token)
.http_only(true)
.finish();
Ok(HttpResponse::Ok()
.append_header((
actix_jwt_session::HEADER_NAME,