Use new authenticator

This commit is contained in:
eraden 2023-08-13 15:31:05 +02:00
parent 7ec783651f
commit 8526a45e13
8 changed files with 94 additions and 83 deletions

1
Cargo.lock generated
View File

@ -2517,7 +2517,6 @@ name = "oswilno"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"actix", "actix",
"actix-jwt-session",
"actix-rt", "actix-rt",
"actix-web", "actix-web",
"actix-web-grants", "actix-web-grants",

View File

@ -1,5 +1,5 @@
use actix_web::{HttpMessage, FromRequest};
use actix_web::{dev::ServiceRequest, HttpResponse}; use actix_web::{dev::ServiceRequest, HttpResponse};
use actix_web::HttpMessage;
use jsonwebtoken::{decode, DecodingKey, Validation}; use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc; use std::sync::Arc;
@ -20,6 +20,10 @@ pub enum Error {
InvalidSession, InvalidSession,
#[error("No http authentication header")] #[error("No http authentication header")]
NoAuthHeader, NoAuthHeader,
#[error("Failed to serialize claims")]
SerializeFailed,
#[error("Unable to write claims to storage")]
WriteFailed,
} }
impl actix_web::ResponseError for Error { impl actix_web::ResponseError for Error {
@ -35,13 +39,41 @@ impl actix_web::ResponseError for Error {
} }
} }
#[derive(Clone)]
pub struct Authenticated<T>(Arc<T>); pub struct Authenticated<T>(Arc<T>);
impl<T> std::ops::Deref for Authenticated<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl<T: Claims> FromRequest for Authenticated<T> {
type Error = actix_web::error::Error;
type Future = std::future::Ready<Result<Self, actix_web::Error>>;
fn from_request(
req: &actix_web::HttpRequest,
_payload: &mut actix_web::dev::Payload,
) -> Self::Future {
let value = req.extensions_mut().get::<Authenticated<T>>().map(Clone::clone);
std::future::ready(value.ok_or_else(|| Error::NotFound.into()))
}
}
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
pub trait TokenStorage { pub trait TokenStorage {
type ClaimsType: Claims; type ClaimsType: Claims;
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>; async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
async fn set_by_jti(
self: Arc<Self>,
claims: Self::ClaimsType,
exp: std::time::Duration,
) -> Result<(), Error>;
} }
struct Extractor; struct Extractor;
@ -53,7 +85,6 @@ impl Extractor {
jwt_validator: Arc<Validation>, jwt_validator: Arc<Validation>,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>, storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let authorisation_header = req let authorisation_header = req
.headers() .headers()
.get("Authorization") .get("Authorization")

View File

@ -3,10 +3,10 @@ use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Tr
use futures_util::future::LocalBoxFuture; use futures_util::future::LocalBoxFuture;
use jsonwebtoken::{DecodingKey, Validation}; use jsonwebtoken::{DecodingKey, Validation};
use redis::AsyncCommands; use redis::AsyncCommands;
use std::future::{ready, Ready};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::future::{ready, Ready};
#[derive(Clone)] #[derive(Clone)]
pub struct RedisStorage<ClaimsType: Claims> { pub struct RedisStorage<ClaimsType: Claims> {
@ -39,6 +39,20 @@ where
.map_err(|_| Error::NotFound)?; .map_err(|_| Error::NotFound)?;
bincode::deserialize(&val).map_err(|_| Error::RecordMalformed) bincode::deserialize(&val).map_err(|_| Error::RecordMalformed)
} }
async fn set_by_jti(
self: Arc<Self>,
claims: Self::ClaimsType,
exp: std::time::Duration,
) -> Result<(), Error> {
let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
let val = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
conn.set_ex::<_, _, String>(claims.jti().as_bytes(), val, exp.as_secs() as usize)
.await
.map_err(|_| Error::WriteFailed)?;
Ok(())
}
} }
pub struct RedisMiddleware<S, ClaimsType> pub struct RedisMiddleware<S, ClaimsType>
@ -80,18 +94,33 @@ where
} }
} }
#[derive(Debug,Clone)] #[derive(Clone)]
pub struct RedisMiddlewareFactory<ClaimsType: Claims> { pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
jwt_decoding_key: Arc<DecodingKey>, jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>, jwt_validator: Arc<Validation>,
storage: Arc<dyn TokenStorage>, storage: Arc<RedisStorage<ClaimsType>>,
_claims_type_marker: PhantomData<ClaimsType>, _claims_type_marker: PhantomData<ClaimsType>,
} }
impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
pub fn new(
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
storage: Arc<RedisStorage<ClaimsType>>,
) -> Self {
Self {
jwt_decoding_key,
jwt_validator,
storage,
_claims_type_marker: Default::default(),
}
}
}
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for RedisMiddlewareFactory<ClaimsType> impl<S, B, ClaimsType> Transform<S, ServiceRequest> for RedisMiddlewareFactory<ClaimsType>
where where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
ClaimsType: DeserializeOwned + 'static, ClaimsType: Claims,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = actix_web::Error; type Error = actix_web::Error;
@ -110,7 +139,6 @@ where
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -127,6 +155,5 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn extract() { async fn extract() {}
}
} }

View File

@ -9,8 +9,8 @@ use sea_orm::ActiveValue::{NotSet, Set};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use oswilno_session::{Claims, Authenticated};
use oswilno_view::Layout; use oswilno_view::Layout;
use oswilno_session::UserSession;
pub fn mount(config: &mut ServiceConfig) { pub fn mount(config: &mut ServiceConfig) {
config.service( config.service(
@ -30,7 +30,9 @@ struct AllPartialParkingSpace {
} }
#[get("/all")] #[get("/all")]
async fn all_parking_spaces(db: Data<sea_orm::DatabaseConnection>) -> Layout<AllPartialParkingSpace> { async fn all_parking_spaces(
db: Data<sea_orm::DatabaseConnection>,
) -> Layout<AllPartialParkingSpace> {
let db = db.into_inner(); let db = db.into_inner();
let main = load_parking_spaces(db).await; let main = load_parking_spaces(db).await;
@ -105,7 +107,7 @@ struct CreateParkingSpace {
async fn create( async fn create(
db: Data<sea_orm::DatabaseConnection>, db: Data<sea_orm::DatabaseConnection>,
p: Form<CreateParkingSpace>, p: Form<CreateParkingSpace>,
session: UserSession, session: Authenticated<Claims>,
) -> HttpResponse { ) -> HttpResponse {
use oswilno_contract::parking_spaces::*; use oswilno_contract::parking_spaces::*;
let CreateParkingSpace { location, spot } = p.into_inner(); let CreateParkingSpace { location, spot } = p.into_inner();
@ -115,7 +117,7 @@ async fn create(
id: NotSet, id: NotSet,
location: Set(location.clone()), location: Set(location.clone()),
spot: Set(spot.map(|n| n as i32)), spot: Set(spot.map(|n| n as i32)),
account_id: Set(session.claims.account_id()), account_id: Set(session.account_id()),
..Default::default() ..Default::default()
}; };
if let Err(_e) = model.save(&*db).await { if let Err(_e) = model.save(&*db).await {

View File

@ -15,7 +15,6 @@ oswilno-config = { path = "../oswilno-config" }
oswilno-parking-space = { path = "../oswilno-parking-space" } oswilno-parking-space = { path = "../oswilno-parking-space" }
oswilno-session = { path = "../oswilno-session" } oswilno-session = { path = "../oswilno-session" }
oswilno-view = { path = "../oswilno-view" } oswilno-view = { path = "../oswilno-view" }
actix-jwt-session = { path = "../actix-jwt-session", features = ["use-redis"] }
redis = { version = "0.17" } redis = { version = "0.17" }
redis-async-pool = "0.2.4" redis-async-pool = "0.2.4"
sea-orm = { version = "0.11", features = ["postgres-array", "runtime-actix-rustls", "sqlx-postgres"] } sea-orm = { version = "0.11", features = ["postgres-array", "runtime-actix-rustls", "sqlx-postgres"] }

View File

@ -50,8 +50,7 @@ async fn main() -> std::io::Result<()> {
let session_config = session_config.clone(); let session_config = session_config.clone();
App::new() App::new()
.wrap(middleware::Logger::default()) .wrap(middleware::Logger::default())
.wrap(actix_jwt_session::RedisMiddleware::new()) .wrap(session_config.factory())
// .wrap(session_config.factory())
.app_data(Data::new(conn.clone())) .app_data(Data::new(conn.clone()))
.app_data(Data::new(redis.clone())) .app_data(Data::new(redis.clone()))
.app_data(Data::new(l10n.clone())) .app_data(Data::new(l10n.clone()))

View File

@ -17,7 +17,7 @@ garde = { version = "0.14.0", features = ["derive"] }
jsonwebtoken = "8.3.0" jsonwebtoken = "8.3.0"
oswilno-contract = { path = "../oswilno-contract" } oswilno-contract = { path = "../oswilno-contract" }
oswilno-view = { path = "../oswilno-view" } oswilno-view = { path = "../oswilno-view" }
actix-jwt-session = { path = "../actix-jwt-session" } actix-jwt-session = { path = "../actix-jwt-session", features = ["use-redis"] }
rand = "0.8.5" rand = "0.8.5"
redis = { version = "0.17" } redis = { version = "0.17" }
redis-async-pool = "0.2.4" redis-async-pool = "0.2.4"

View File

@ -1,14 +1,11 @@
use std::ops::Add; use std::ops::Add;
use std::sync::Arc; use std::sync::Arc;
use actix_jwt_authc::*; pub use actix_jwt_session::{Error, RedisMiddlewareFactory, Authenticated};
use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpResponse}; use actix_web::{get, post, HttpResponse};
use askama_actix::Template; use askama_actix::Template;
use autometrics::autometrics; use autometrics::autometrics;
use futures::channel::{mpsc, mpsc::Sender};
use futures::stream::Stream;
use futures::SinkExt;
use garde::Validate; use garde::Validate;
use jsonwebtoken::*; use jsonwebtoken::*;
use oswilno_view::{Errors, Lang, Layout, TranslationStorage}; use oswilno_view::{Errors, Lang, Layout, TranslationStorage};
@ -18,14 +15,13 @@ use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use time::ext::*; use time::ext::*;
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::sync::Mutex;
mod extract_session; mod extract_session;
mod hashing; mod hashing;
pub use oswilno_view::filters; pub use oswilno_view::filters;
pub type UserSession = Authenticated<Claims>; pub type UserSession = Claims;
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct JWTTtl(time::Duration); pub struct JWTTtl(time::Duration);
@ -51,6 +47,12 @@ pub struct Claims {
jwt_id: uuid::Uuid, jwt_id: uuid::Uuid,
} }
impl actix_jwt_session::Claims for Claims {
fn jti(&self) -> uuid::Uuid {
self.jwt_id
}
}
impl Claims { impl Claims {
pub fn account_id(&self) -> i32 { pub fn account_id(&self) -> i32 {
self.subject self.subject
@ -76,15 +78,13 @@ const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
#[derive(Clone)] #[derive(Clone)]
pub struct SessionConfigurator { pub struct SessionConfigurator {
jwt_ttl: Data<JWTTtl>, jwt_ttl: Data<JWTTtl>,
invalidated_jwts_store: Data<InvalidatedJWTStore>,
encoding_key: Data<EncodingKey>, encoding_key: Data<EncodingKey>,
factory: AuthenticateMiddlewareFactory<Claims>, factory: RedisMiddlewareFactory<Claims>,
} }
impl SessionConfigurator { impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) { pub fn app_data(self, config: &mut ServiceConfig) {
config config
.app_data(self.invalidated_jwts_store)
.app_data(self.encoding_key) .app_data(self.encoding_key)
.app_data(self.jwt_ttl) .app_data(self.jwt_ttl)
.service(login) .service(login)
@ -97,7 +97,7 @@ impl SessionConfigurator {
.service(register_partial_view); .service(register_partial_view);
} }
pub fn factory(&self) -> AuthenticateMiddlewareFactory<Claims> { pub fn factory(&self) -> RedisMiddlewareFactory<Claims> {
self.factory.clone() self.factory.clone()
} }
@ -133,18 +133,13 @@ impl SessionConfigurator {
let jwt_ttl = JWTTtl(31.days()); let jwt_ttl = JWTTtl(31.days());
let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
let validator = Validation::new(JWT_SIGNING_ALGO); let validator = Validation::new(JWT_SIGNING_ALGO);
let auth_middleware_settings = AuthenticateMiddlewareSettings { let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new(
jwt_decoding_key: jwt_signing_keys.decoding_key, Arc::new(jwt_signing_keys.decoding_key),
jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), Arc::new(validator),
jwt_validator: validator, Arc::new(actix_jwt_session::RedisStorage::new(redis)),
jwt_session_key: None, );
};
let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(redis);
let auth_middleware_factory =
AuthenticateMiddlewareFactory::<Claims>::new(stream, auth_middleware_settings.clone());
Self { Self {
invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()),
encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()),
jwt_ttl: Data::new(jwt_ttl.clone()), jwt_ttl: Data::new(jwt_ttl.clone()),
factory: auth_middleware_factory, factory: auth_middleware_factory,
@ -308,8 +303,8 @@ async fn login_inner(
#[autometrics] #[autometrics]
#[get("/session")] #[get("/session")]
async fn session_info(authenticated: UserSession) -> Result<HttpResponse, Error> { async fn session_info(authenticated: Authenticated<Claims>) -> Result<HttpResponse, Error> {
Ok(HttpResponse::Ok().json(authenticated)) Ok(HttpResponse::Ok().json(&*authenticated))
} }
#[autometrics] #[autometrics]
@ -320,7 +315,7 @@ async fn logout(
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
{ {
use redis::AsyncCommands; use redis::AsyncCommands;
let jwt_id = authenticated.claims.jwt_id; let jwt_id = authenticated.jwt_id;
if let Ok(mut conn) = redis.get().await { if let Ok(mut conn) = redis.get().await {
if conn.del::<_, String>(jwt_id.as_bytes()).await.is_err() {} if conn.del::<_, String>(jwt_id.as_bytes()).await.is_err() {}
} }
@ -506,48 +501,7 @@ async fn register_internal(
.json(EmptyResponse {})) .json(EmptyResponse {}))
} }
#[derive(Clone)] pub struct JwtSigningKeys {
struct InvalidatedJWTStore {
// store: Arc<DashSet<JWT>>,
redis: redis_async_pool::RedisPool,
tx: Arc<Mutex<Sender<InvalidatedTokensEvent>>>,
}
impl InvalidatedJWTStore {
/// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s
fn new_with_stream(
redis: redis_async_pool::RedisPool,
) -> (
InvalidatedJWTStore,
impl Stream<Item = InvalidatedTokensEvent>,
) {
// let invalidated = Arc::new(DashSet::new());
let (tx, rx) = mpsc::channel(100);
let tx_to_hold = Arc::new(Mutex::new(tx));
(
InvalidatedJWTStore {
// store: invalidated,
redis,
tx: tx_to_hold,
},
rx,
)
}
async fn add_to_invalidated(&self, authenticated: Authenticated<Claims>) {
// self.store.insert(authenticated.jwt.clone());
let mut tx = self.tx.lock().await;
if let Err(_e) = tx
.send(InvalidatedTokensEvent::Add(authenticated.jwt))
.await
{
#[cfg(feature = "tracing")]
error!(error = ?_e, "Failed to send update on adding to invalidated")
}
}
}
struct JwtSigningKeys {
encoding_key: EncodingKey, encoding_key: EncodingKey,
decoding_key: DecodingKey, decoding_key: DecodingKey,
} }