2 kinds of session extractors
This commit is contained in:
parent
7069e58f1d
commit
f20dbfc14b
@ -1,10 +1,11 @@
|
|||||||
use actix_web::{dev::ServiceRequest, HttpResponse};
|
use actix_web::{dev::ServiceRequest, HttpResponse};
|
||||||
use actix_web::{FromRequest, HttpMessage};
|
use actix_web::{FromRequest, HttpMessage};
|
||||||
|
use async_trait::async_trait;
|
||||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
|
use std::marker::PhantomData;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
pub static HEADER_NAME: &str = "Authorization";
|
pub static HEADER_NAME: &str = "Authorization";
|
||||||
|
|
||||||
@ -185,8 +186,9 @@ impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
pub trait Extractor {
|
pub trait Extractor<ClaimsType: Claims>: Send + Sync + 'static {
|
||||||
async fn extract_jwt<ClaimsType: Claims>(
|
async fn extract_jwt(
|
||||||
|
&self,
|
||||||
req: &ServiceRequest,
|
req: &ServiceRequest,
|
||||||
jwt_encoding_key: Arc<EncodingKey>,
|
jwt_encoding_key: Arc<EncodingKey>,
|
||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
@ -194,27 +196,50 @@ pub trait Extractor {
|
|||||||
storage: SessionStorage<ClaimsType>,
|
storage: SessionStorage<ClaimsType>,
|
||||||
) -> Result<(), Error>;
|
) -> Result<(), Error>;
|
||||||
|
|
||||||
fn decode<ClaimsType: Claims>(value: &str,
|
fn decode(
|
||||||
|
&self,
|
||||||
|
value: &str,
|
||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
algorithm: Algorithm,
|
algorithm: Algorithm,
|
||||||
|
|
||||||
) -> Result<ClaimsType, Error> {
|
) -> Result<ClaimsType, Error> {
|
||||||
decode::<ClaimsType>(value, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
|
decode::<ClaimsType>(value, &*jwt_decoding_key, &Validation::new(algorithm))
|
||||||
|_e| {
|
.map_err(|_e| {
|
||||||
// let error_message = e.to_string();
|
// let error_message = e.to_string();
|
||||||
Error::InvalidSession
|
Error::InvalidSession
|
||||||
},
|
})
|
||||||
)
|
.map(|t| t.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn validate(
|
||||||
|
&self,
|
||||||
|
claims: &ClaimsType,
|
||||||
|
storage: SessionStorage<ClaimsType>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let stored = storage
|
||||||
|
.0
|
||||||
|
.clone()
|
||||||
|
.get_from_jti(claims.jti())
|
||||||
|
.await
|
||||||
|
.map_err(|_| Error::InvalidSession)?;
|
||||||
|
|
||||||
|
if &stored != claims {
|
||||||
|
return Err(Error::InvalidSession);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CookieExtractor;
|
#[derive(Default)]
|
||||||
|
pub struct CookieExtractor<ClaimsType>(PhantomData<ClaimsType>);
|
||||||
|
|
||||||
|
impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
|
||||||
|
pub fn new() -> Self { Self(Default::default()) }
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
impl Extractor for CookieExtractor {
|
impl<ClaimsType: Claims> Extractor<ClaimsType> for CookieExtractor<ClaimsType> {
|
||||||
async fn extract_jwt<ClaimsType: Claims>(
|
async fn extract_jwt(
|
||||||
|
&self,
|
||||||
req: &ServiceRequest,
|
req: &ServiceRequest,
|
||||||
jwt_encoding_key: Arc<EncodingKey>,
|
jwt_encoding_key: Arc<EncodingKey>,
|
||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
@ -222,16 +247,29 @@ impl Extractor for CookieExtractor {
|
|||||||
storage: SessionStorage<ClaimsType>,
|
storage: SessionStorage<ClaimsType>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let Some(cookie) = req.cookie(HEADER_NAME) else {return Ok(())};
|
let Some(cookie) = req.cookie(HEADER_NAME) else {return Ok(())};
|
||||||
let value = cookie.value();
|
let as_str = cookie.value();
|
||||||
|
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
|
||||||
|
self.validate(&decoded_claims, storage).await?;
|
||||||
|
req.extensions_mut().insert(Authenticated {
|
||||||
|
claims: Arc::new(decoded_claims),
|
||||||
|
jwt_encoding_key,
|
||||||
|
algorithm,
|
||||||
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct HeaderExtractor;
|
#[derive(Default)]
|
||||||
|
pub struct HeaderExtractor<ClaimsType>(PhantomData<ClaimsType>);
|
||||||
|
|
||||||
|
impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
|
||||||
|
pub fn new() -> Self { Self(Default::default()) }
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
impl Extractor for HeaderExtractor {
|
impl<ClaimsType: Claims> Extractor<ClaimsType> for HeaderExtractor<ClaimsType> {
|
||||||
async fn extract_jwt<ClaimsType: Claims>(
|
async fn extract_jwt(
|
||||||
|
&self,
|
||||||
req: &ServiceRequest,
|
req: &ServiceRequest,
|
||||||
jwt_encoding_key: Arc<EncodingKey>,
|
jwt_encoding_key: Arc<EncodingKey>,
|
||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
@ -253,27 +291,10 @@ impl Extractor for HeaderExtractor {
|
|||||||
.or_else(|| as_str.strip_prefix("bearer "))
|
.or_else(|| as_str.strip_prefix("bearer "))
|
||||||
.unwrap_or(as_str);
|
.unwrap_or(as_str);
|
||||||
|
|
||||||
let decoded_claims =
|
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
|
||||||
decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
|
self.validate(&decoded_claims, storage).await?;
|
||||||
|_e| {
|
|
||||||
// let error_message = e.to_string();
|
|
||||||
Error::InvalidSession
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let stored = storage
|
|
||||||
.0
|
|
||||||
.clone()
|
|
||||||
.get_from_jti(decoded_claims.claims.jti())
|
|
||||||
.await
|
|
||||||
.map_err(|_| Error::InvalidSession)?;
|
|
||||||
|
|
||||||
if stored != decoded_claims.claims {
|
|
||||||
return Err(Error::InvalidSession);
|
|
||||||
}
|
|
||||||
|
|
||||||
req.extensions_mut().insert(Authenticated {
|
req.extensions_mut().insert(Authenticated {
|
||||||
claims: Arc::new(decoded_claims.claims),
|
claims: Arc::new(decoded_claims),
|
||||||
jwt_encoding_key,
|
jwt_encoding_key,
|
||||||
algorithm,
|
algorithm,
|
||||||
});
|
});
|
||||||
|
@ -89,6 +89,7 @@ where
|
|||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
algorithm: Algorithm,
|
algorithm: Algorithm,
|
||||||
storage: SessionStorage<ClaimsType>,
|
storage: SessionStorage<ClaimsType>,
|
||||||
|
extractors: Arc<Vec<Box<dyn Extractor<ClaimsType>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
|
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
|
||||||
@ -110,16 +111,25 @@ where
|
|||||||
let jwt_encoding_key = self.jwt_encoding_key.clone();
|
let jwt_encoding_key = self.jwt_encoding_key.clone();
|
||||||
let algorithm = self.algorithm;
|
let algorithm = self.algorithm;
|
||||||
let storage = self.storage.clone();
|
let storage = self.storage.clone();
|
||||||
|
let extractors = self.extractors.clone();
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
HeaderExtractor::extract_jwt(
|
let mut last_error = None;
|
||||||
|
for extractor in extractors.iter() {
|
||||||
|
match extractor.extract_jwt(
|
||||||
&req,
|
&req,
|
||||||
jwt_encoding_key,
|
jwt_encoding_key.clone(),
|
||||||
jwt_decoding_key,
|
jwt_decoding_key.clone(),
|
||||||
algorithm,
|
algorithm,
|
||||||
storage,
|
storage.clone(),
|
||||||
)
|
).await {
|
||||||
.await?;
|
Ok(_) => break,
|
||||||
|
Err(e) => { last_error = Some(e); },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if let Some(e) = last_error {
|
||||||
|
return Err(e)?;
|
||||||
|
}
|
||||||
let res = svc.call(req).await?;
|
let res = svc.call(req).await?;
|
||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
@ -133,6 +143,7 @@ pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
|
|||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
algorithm: Algorithm,
|
algorithm: Algorithm,
|
||||||
storage: SessionStorage<ClaimsType>,
|
storage: SessionStorage<ClaimsType>,
|
||||||
|
extractors: Arc<Vec<Box<dyn Extractor<ClaimsType>>>>,
|
||||||
_claims_type_marker: PhantomData<ClaimsType>,
|
_claims_type_marker: PhantomData<ClaimsType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,6 +153,7 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
|
|||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
algorithm: Algorithm,
|
algorithm: Algorithm,
|
||||||
pool: redis_async_pool::RedisPool,
|
pool: redis_async_pool::RedisPool,
|
||||||
|
extractors: Vec<Box<dyn Extractor<ClaimsType>>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm));
|
let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm));
|
||||||
|
|
||||||
@ -150,6 +162,7 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
|
|||||||
jwt_decoding_key,
|
jwt_decoding_key,
|
||||||
algorithm,
|
algorithm,
|
||||||
storage: SessionStorage(storage),
|
storage: SessionStorage(storage),
|
||||||
|
extractors: Arc::new(extractors),
|
||||||
_claims_type_marker: Default::default(),
|
_claims_type_marker: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,6 +190,7 @@ where
|
|||||||
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||||
jwt_decoding_key: self.jwt_decoding_key.clone(),
|
jwt_decoding_key: self.jwt_decoding_key.clone(),
|
||||||
algorithm: self.algorithm,
|
algorithm: self.algorithm,
|
||||||
|
extractors: self.extractors.clone(),
|
||||||
_claims_type_marker: PhantomData,
|
_claims_type_marker: PhantomData,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,7 @@ pub fn mount(config: &mut ServiceConfig) {
|
|||||||
scope("/parking-spaces")
|
scope("/parking-spaces")
|
||||||
.service(form_show)
|
.service(form_show)
|
||||||
.service(all_parking_spaces)
|
.service(all_parking_spaces)
|
||||||
.service(create)
|
.service(create),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use std::ops::Add;
|
use std::ops::Add;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use actix_jwt_session::SessionStorage;
|
use actix_jwt_session::{SessionStorage, CookieExtractor, HeaderExtractor};
|
||||||
pub use actix_jwt_session::{Error, RedisMiddlewareFactory};
|
pub use actix_jwt_session::{Error, RedisMiddlewareFactory};
|
||||||
use actix_web::web::{Data, Form, ServiceConfig};
|
use actix_web::web::{Data, Form, ServiceConfig};
|
||||||
use actix_web::{get, post, HttpRequest, HttpResponse};
|
use actix_web::{get, post, HttpRequest, HttpResponse};
|
||||||
@ -133,6 +133,7 @@ impl SessionConfigurator {
|
|||||||
Arc::new(jwt_signing_keys.decoding_key),
|
Arc::new(jwt_signing_keys.decoding_key),
|
||||||
Algorithm::EdDSA,
|
Algorithm::EdDSA,
|
||||||
redis,
|
redis,
|
||||||
|
vec![Box::new(CookieExtractor::<Claims>::new()), Box::new(HeaderExtractor::<Claims>::new())]
|
||||||
);
|
);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
@ -295,7 +296,9 @@ async fn login_inner(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let cookie = actix_web::cookie::Cookie::build(actix_jwt_session::HEADER_NAME, &bearer_token).http_only(true).finish();
|
let cookie = actix_web::cookie::Cookie::build(actix_jwt_session::HEADER_NAME, &bearer_token)
|
||||||
|
.http_only(true)
|
||||||
|
.finish();
|
||||||
Ok(HttpResponse::Ok()
|
Ok(HttpResponse::Ok()
|
||||||
.append_header((
|
.append_header((
|
||||||
actix_jwt_session::HEADER_NAME,
|
actix_jwt_session::HEADER_NAME,
|
||||||
|
Loading…
Reference in New Issue
Block a user