Use new authenticator

This commit is contained in:
eraden 2023-08-14 07:38:31 +02:00
parent 932665a767
commit e34a306668
6 changed files with 58 additions and 24 deletions

2
Cargo.lock generated
View File

@ -186,9 +186,11 @@ dependencies = [
"futures",
"futures-lite",
"futures-util",
"garde",
"jsonwebtoken",
"redis",
"redis-async-pool",
"ring",
"serde",
"thiserror",
"tokio 1.30.0",

View File

@ -22,3 +22,11 @@ serde = { version = "1.0.183", features = ["derive"] }
thiserror = "1.0.44"
tokio = { version = "1.30.0", features = ["full"] }
uuid = { version = "1.4.1", features = ["v4"] }
[[test]]
name = "ensure_redis_flow"
path = "./tests/ensure_redis_flow.rs"
[dev-dependencies]
garde = "0.14.0"
ring = "0.16.20"

View File

@ -1,6 +1,6 @@
use actix_web::{HttpMessage, FromRequest};
use actix_web::{dev::ServiceRequest, HttpResponse};
use jsonwebtoken::{decode, DecodingKey, Validation};
use actix_web::{FromRequest, HttpMessage};
use jsonwebtoken::{decode, DecodingKey, Validation, EncodingKey, encode, Algorithm};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
@ -42,13 +42,23 @@ impl actix_web::ResponseError for Error {
#[derive(Clone)]
#[cfg_attr(feature = "serde-transparent", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-transparent", serde(transparent))]
pub struct Authenticated<T>(Arc<T>);
pub struct Authenticated<T> {
claims: Arc<T>,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
}
impl<T> std::ops::Deref for Authenticated<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&*self.0
&*self.claims
}
}
impl<T: Claims> Authenticated<T> {
pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
encode(&jsonwebtoken::Header::new(self.algorithm), &*self.claims, &*self.jwt_encoding_key)
}
}
@ -60,7 +70,10 @@ impl<T: Claims> FromRequest for Authenticated<T> {
req: &actix_web::HttpRequest,
_payload: &mut actix_web::dev::Payload,
) -> Self::Future {
let value = req.extensions_mut().get::<Authenticated<T>>().map(Clone::clone);
let value = req
.extensions_mut()
.get::<Authenticated<T>>()
.map(Clone::clone);
std::future::ready(value.ok_or_else(|| Error::NotFound.into()))
}
}
@ -83,19 +96,22 @@ struct Extractor;
impl Extractor {
async fn extract_bearer_jwt<ClaimsType: Claims>(
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
) -> Result<(), Error> {
let authorisation_header = req
let Some(authorisation_header) = req
.headers()
.get("Authorization")
.ok_or(Error::NoAuthHeader)?;
else {
return Ok(())
};
let as_str = authorisation_header
.to_str()
.map_err(|_| Error::NoAuthHeader)?;
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &*jwt_validator)
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm))
.map_err(|_e| {
// let error_message = e.to_string();
Error::InvalidSession
@ -111,7 +127,9 @@ impl Extractor {
}
req.extensions_mut()
.insert(Authenticated(Arc::new(decoded_claims.claims)));
.insert(Authenticated {
claims: Arc::new(decoded_claims.claims),
});
Ok(())
}
}

View File

@ -61,9 +61,10 @@ where
{
_claims_type_marker: std::marker::PhantomData<ClaimsType>,
service: Rc<S>,
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
storage: Arc<RedisStorage<ClaimsType>>,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
}
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
@ -82,11 +83,12 @@ where
let svc = self.service.clone();
let jwt_decoding_key = self.jwt_decoding_key.clone();
let validation = self.jwt_validator.clone();
let jwt_encoding_key = self.jwt_encoding_key.clone();
let algorithm = self.algorithm;
let storage = self.storage.clone();
async move {
Extractor::extract_bearer_jwt(&req, jwt_decoding_key, validation, storage).await?;
Extractor::extract_bearer_jwt(&req, jwt_encoding_key, jwt_decoding_key, algorithm, storage).await?;
let res = svc.call(req).await?;
Ok(res)
}
@ -96,22 +98,26 @@ where
#[derive(Clone)]
pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
storage: Arc<RedisStorage<ClaimsType>>,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
_claims_type_marker: PhantomData<ClaimsType>,
}
impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
pub fn new(
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
storage: Arc<RedisStorage<ClaimsType>>,
algorithm: Algorithm,
pool: redis_async_pool::RedisPool,
) -> Self {
Self {
jwt_encoding_key,
jwt_decoding_key,
jwt_validator,
storage,
algorithm,
storage: RedisStorage::new(
pool),
_claims_type_marker: Default::default(),
}
}

View File

@ -9,7 +9,7 @@ use sea_orm::ActiveValue::{NotSet, Set};
use std::collections::HashMap;
use std::sync::Arc;
use oswilno_session::{Claims, Authenticated};
use oswilno_session::{Authenticated, Claims};
use oswilno_view::Layout;
pub fn mount(config: &mut ServiceConfig) {

View File

@ -1,7 +1,7 @@
use std::ops::Add;
use std::sync::Arc;
pub use actix_jwt_session::{Error, RedisMiddlewareFactory, Authenticated};
pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpResponse};
use askama_actix::Template;