general purpose middleware

This commit is contained in:
Adrian Woźniak 2023-08-29 16:15:36 +02:00
parent 7fe4b285ce
commit 3d42e30061
4 changed files with 195 additions and 2 deletions

View File

@ -8,3 +8,4 @@ members = [
'./crates/oswilno-actix-admin',
'./crates/actix-jwt-session',
]
resolver = '2'

View File

@ -121,7 +121,9 @@
//! }
//! ```
use actix_web::{dev::ServiceRequest, HttpResponse};
use std::future::{ready, Ready};
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use actix_web::HttpResponse;
use actix_web::{FromRequest, HttpMessage};
use async_trait::async_trait;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
@ -600,6 +602,181 @@ impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
}
}
pub struct SessionMiddleware<S, ClaimsType>
where
ClaimsType: Claims,
{
_claims_type_marker: std::marker::PhantomData<ClaimsType>,
service: std::rc::Rc<S>,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
extractors: Arc<Vec<Box<dyn SessionExtractor<ClaimsType>>>>,
}
impl<S, B, ClaimsType> Service<ServiceRequest> for SessionMiddleware<S, ClaimsType>
where
ClaimsType: Claims,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
use futures_lite::FutureExt;
let svc = self.service.clone();
let jwt_decoding_key = self.jwt_decoding_key.clone();
let jwt_encoding_key = self.jwt_encoding_key.clone();
let algorithm = self.algorithm;
let storage = self.storage.clone();
let extractors = self.extractors.clone();
async move {
let mut last_error = None;
for extractor in extractors.iter() {
match extractor
.extract_jwt(
&req,
jwt_encoding_key.clone(),
jwt_decoding_key.clone(),
algorithm,
storage.clone(),
)
.await
{
Ok(_) => break,
Err(e) => {
last_error = Some(e);
}
};
}
if let Some(e) = last_error {
return Err(e)?;
}
let res = svc.call(req).await?;
Ok(res)
}
.boxed_local()
}
}
pub struct MiddlewareFactory<ClaimsType: Claims> {
on_error: Option<Arc<dyn Fn(Error) -> Option<HttpResponse>>>,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
extractors: Arc<Vec<Box<dyn SessionExtractor<ClaimsType>>>>,
inner: Box<dyn MiddlewareFactoryImpl>,
}
impl<ClaimsType: Claims> MiddlewareFactory<ClaimsType> {
pub fn build(
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
) -> MiddlewareFactoryBuilder<ClaimsType> {
MiddlewareFactoryBuilder {
jwt_decoding_key,
jwt_encoding_key,
algorithm,
on_error: None,
storage: None,
extractors: Vec::new(),
inner: None,
}
}
}
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for MiddlewareFactory<ClaimsType>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
ClaimsType: Claims,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Transform = RedisMiddleware<S, ClaimsType>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(RedisMiddleware {
service: std::rc::Rc::new(service),
storage: self.storage.clone(),
jwt_encoding_key: self.jwt_encoding_key.clone(),
jwt_decoding_key: self.jwt_decoding_key.clone(),
algorithm: self.algorithm,
extractors: self.extractors.clone(),
_claims_type_marker: PhantomData,
}))
}
}
pub struct MiddlewareFactoryBuilder<ClaimsType: Claims> {
on_error: Option<Arc<dyn Fn(Error) -> Option<HttpResponse>>>,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: Option<SessionStorage<ClaimsType>>,
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
inner: Option<Box<dyn MiddlewareFactoryImpl>>,
}
impl<ClaimsType: Claims> MiddlewareFactoryBuilder<ClaimsType> {
pub fn error_handler<F: Fn(Error) -> Option<HttpResponse> + 'static>(
mut self,
handler: F,
) -> Self {
self.on_error = Some(Arc::new(handler));
self
}
pub fn with_storage<T: TokenStorage + 'static>(mut self, storage: Arc<T>) -> Self {
self.storage = Some(SessionStorage::new(
storage,
self.jwt_encoding_key.clone(),
self.algorithm.clone(),
));
self
}
pub fn with_extractor<E: SessionExtractor<ClaimsType>>(mut self, extractor: E) -> Self {
self.extractors.push(Box::new(extractor));
self
}
pub fn with_inner<Inner: MiddlewareFactoryImpl>(mut self, inner: Inner) -> Self {
self.inner = Some(Box::new(inner));
self
}
pub fn build(self) -> MiddlewareFactory<ClaimsType> {
let Self {
jwt_decoding_key,
on_error,
jwt_encoding_key,
algorithm,
storage,
extractors,
inner,
} = self;
MiddlewareFactory {
jwt_decoding_key,
jwt_encoding_key,
on_error,
algorithm,
storage: storage.expect("No implementation of session storage"),
extractors: Arc::new(extractors),
inner: inner.expect("No implementation of session middleware factory"),
}
}
}
/// Trait allowing to extract JWt token from [actix_web::dev::ServiceRequest]
///
/// Two extractor are implemented by default

View File

@ -197,3 +197,16 @@ where
}))
}
}
impl<ClaimsType: Claims> MiddlewareFactoryBuilder<ClaimsType> {
pub fn with_redis(self, pool: redis_async_pool::RedisPool) -> Self {
let s = self.with_storage(Arc::new(RedisStorage::new(pool.clone())));
s.with_inner(RedisMiddlewareFactory::new(
self.jwt_encoding_key.clone(),
self.jwt_decoding_key.clone(),
self.algorithm.clone(),
pool,
extractors,
))
}
}

View File

@ -525,7 +525,9 @@ async fn register_internal(
p.clone()
})?;
let (login_taken, email_taken) = if let Some(query_result) = query_result {
let Ok((login_taken, email_taken)): Result<(bool,bool), _> = query_result.try_get_many("", &["login_taken".into(), "email_taken".into()]) else {
let Ok((login_taken, email_taken)): Result<(bool, bool), _> =
query_result.try_get_many("", &["login_taken".into(), "email_taken".into()])
else {
tracing::warn!("Failed to fetch fields from query result while checking if account info exists in db");
errors.push_global("Something went wrong");
return Err(p);