general purpose middleware
This commit is contained in:
parent
7fe4b285ce
commit
3d42e30061
@ -8,3 +8,4 @@ members = [
|
||||
'./crates/oswilno-actix-admin',
|
||||
'./crates/actix-jwt-session',
|
||||
]
|
||||
resolver = '2'
|
||||
|
@ -121,7 +121,9 @@
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use actix_web::{dev::ServiceRequest, HttpResponse};
|
||||
use std::future::{ready, Ready};
|
||||
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
|
||||
use actix_web::HttpResponse;
|
||||
use actix_web::{FromRequest, HttpMessage};
|
||||
use async_trait::async_trait;
|
||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
||||
@ -600,6 +602,181 @@ impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SessionMiddleware<S, ClaimsType>
|
||||
where
|
||||
ClaimsType: Claims,
|
||||
{
|
||||
_claims_type_marker: std::marker::PhantomData<ClaimsType>,
|
||||
service: std::rc::Rc<S>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
extractors: Arc<Vec<Box<dyn SessionExtractor<ClaimsType>>>>,
|
||||
}
|
||||
|
||||
impl<S, B, ClaimsType> Service<ServiceRequest> for SessionMiddleware<S, ClaimsType>
|
||||
where
|
||||
ClaimsType: Claims,
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
|
||||
{
|
||||
type Response = ServiceResponse<B>;
|
||||
type Error = actix_web::Error;
|
||||
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
forward_ready!(service);
|
||||
|
||||
fn call(&self, req: ServiceRequest) -> Self::Future {
|
||||
use futures_lite::FutureExt;
|
||||
|
||||
let svc = self.service.clone();
|
||||
let jwt_decoding_key = self.jwt_decoding_key.clone();
|
||||
let jwt_encoding_key = self.jwt_encoding_key.clone();
|
||||
let algorithm = self.algorithm;
|
||||
let storage = self.storage.clone();
|
||||
let extractors = self.extractors.clone();
|
||||
|
||||
async move {
|
||||
let mut last_error = None;
|
||||
for extractor in extractors.iter() {
|
||||
match extractor
|
||||
.extract_jwt(
|
||||
&req,
|
||||
jwt_encoding_key.clone(),
|
||||
jwt_decoding_key.clone(),
|
||||
algorithm,
|
||||
storage.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => break,
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
}
|
||||
};
|
||||
}
|
||||
if let Some(e) = last_error {
|
||||
return Err(e)?;
|
||||
}
|
||||
let res = svc.call(req).await?;
|
||||
Ok(res)
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MiddlewareFactory<ClaimsType: Claims> {
|
||||
on_error: Option<Arc<dyn Fn(Error) -> Option<HttpResponse>>>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
extractors: Arc<Vec<Box<dyn SessionExtractor<ClaimsType>>>>,
|
||||
inner: Box<dyn MiddlewareFactoryImpl>,
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> MiddlewareFactory<ClaimsType> {
|
||||
pub fn build(
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
) -> MiddlewareFactoryBuilder<ClaimsType> {
|
||||
MiddlewareFactoryBuilder {
|
||||
jwt_decoding_key,
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
on_error: None,
|
||||
storage: None,
|
||||
extractors: Vec::new(),
|
||||
inner: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for MiddlewareFactory<ClaimsType>
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
|
||||
ClaimsType: Claims,
|
||||
{
|
||||
type Response = ServiceResponse<B>;
|
||||
type Error = actix_web::Error;
|
||||
type Transform = RedisMiddleware<S, ClaimsType>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||
|
||||
fn new_transform(&self, service: S) -> Self::Future {
|
||||
ready(Ok(RedisMiddleware {
|
||||
service: std::rc::Rc::new(service),
|
||||
storage: self.storage.clone(),
|
||||
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||
jwt_decoding_key: self.jwt_decoding_key.clone(),
|
||||
algorithm: self.algorithm,
|
||||
extractors: self.extractors.clone(),
|
||||
_claims_type_marker: PhantomData,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MiddlewareFactoryBuilder<ClaimsType: Claims> {
|
||||
on_error: Option<Arc<dyn Fn(Error) -> Option<HttpResponse>>>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: Option<SessionStorage<ClaimsType>>,
|
||||
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
|
||||
inner: Option<Box<dyn MiddlewareFactoryImpl>>,
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> MiddlewareFactoryBuilder<ClaimsType> {
|
||||
pub fn error_handler<F: Fn(Error) -> Option<HttpResponse> + 'static>(
|
||||
mut self,
|
||||
handler: F,
|
||||
) -> Self {
|
||||
self.on_error = Some(Arc::new(handler));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_storage<T: TokenStorage + 'static>(mut self, storage: Arc<T>) -> Self {
|
||||
self.storage = Some(SessionStorage::new(
|
||||
storage,
|
||||
self.jwt_encoding_key.clone(),
|
||||
self.algorithm.clone(),
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_extractor<E: SessionExtractor<ClaimsType>>(mut self, extractor: E) -> Self {
|
||||
self.extractors.push(Box::new(extractor));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_inner<Inner: MiddlewareFactoryImpl>(mut self, inner: Inner) -> Self {
|
||||
self.inner = Some(Box::new(inner));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> MiddlewareFactory<ClaimsType> {
|
||||
let Self {
|
||||
jwt_decoding_key,
|
||||
on_error,
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
storage,
|
||||
extractors,
|
||||
inner,
|
||||
} = self;
|
||||
MiddlewareFactory {
|
||||
jwt_decoding_key,
|
||||
jwt_encoding_key,
|
||||
on_error,
|
||||
algorithm,
|
||||
storage: storage.expect("No implementation of session storage"),
|
||||
extractors: Arc::new(extractors),
|
||||
inner: inner.expect("No implementation of session middleware factory"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait allowing to extract JWt token from [actix_web::dev::ServiceRequest]
|
||||
///
|
||||
/// Two extractor are implemented by default
|
||||
|
@ -197,3 +197,16 @@ where
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> MiddlewareFactoryBuilder<ClaimsType> {
|
||||
pub fn with_redis(self, pool: redis_async_pool::RedisPool) -> Self {
|
||||
let s = self.with_storage(Arc::new(RedisStorage::new(pool.clone())));
|
||||
s.with_inner(RedisMiddlewareFactory::new(
|
||||
self.jwt_encoding_key.clone(),
|
||||
self.jwt_decoding_key.clone(),
|
||||
self.algorithm.clone(),
|
||||
pool,
|
||||
extractors,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
@ -525,7 +525,9 @@ async fn register_internal(
|
||||
p.clone()
|
||||
})?;
|
||||
let (login_taken, email_taken) = if let Some(query_result) = query_result {
|
||||
let Ok((login_taken, email_taken)): Result<(bool,bool), _> = query_result.try_get_many("", &["login_taken".into(), "email_taken".into()]) else {
|
||||
let Ok((login_taken, email_taken)): Result<(bool, bool), _> =
|
||||
query_result.try_get_many("", &["login_taken".into(), "email_taken".into()])
|
||||
else {
|
||||
tracing::warn!("Failed to fetch fields from query result while checking if account info exists in db");
|
||||
errors.push_global("Something went wrong");
|
||||
return Err(p);
|
||||
|
Loading…
Reference in New Issue
Block a user