diff --git a/crates/actix-jwt-session/src/lib.rs b/crates/actix-jwt-session/src/lib.rs index d9662dc..920e532 100644 --- a/crates/actix-jwt-session/src/lib.rs +++ b/crates/actix-jwt-session/src/lib.rs @@ -1,6 +1,6 @@ use actix_web::{dev::ServiceRequest, HttpResponse}; use actix_web::{FromRequest, HttpMessage}; -use jsonwebtoken::{decode, DecodingKey, Validation, EncodingKey, encode, Algorithm}; +use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation}; use serde::{de::DeserializeOwned, Serialize}; use std::sync::Arc; @@ -58,7 +58,11 @@ impl std::ops::Deref for Authenticated { impl Authenticated { pub fn encode(&self) -> Result { - encode(&jsonwebtoken::Header::new(self.algorithm), &*self.claims, &*self.jwt_encoding_key) + encode( + &jsonwebtoken::Header::new(self.algorithm), + &*self.claims, + &*self.jwt_encoding_key, + ) } } @@ -79,7 +83,7 @@ impl FromRequest for Authenticated { } #[async_trait::async_trait(?Send)] -pub trait TokenStorage { +pub trait TokenStorage: Send + Sync { type ClaimsType: Claims; async fn get_from_jti(self: Arc, jti: uuid::Uuid) -> Result; @@ -89,6 +93,48 @@ pub trait TokenStorage { claims: Self::ClaimsType, exp: std::time::Duration, ) -> Result<(), Error>; + + fn jwt_encoding_key(&self) -> Arc; + + fn algorithm(&self) -> Algorithm; +} + +#[derive(Clone)] +pub struct SessionStorage(Arc>); + +impl std::ops::Deref for SessionStorage { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl SessionStorage { + pub async fn set_by_jti( + &self, + claims: ClaimsType, + exp: std::time::Duration, + ) -> Result<(), Error> { + self.0.clone().set_by_jti(claims, exp).await + } + + pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result { + self.0.clone().get_from_jti(jti).await + } + + pub async fn store( + &self, + claims: ClaimsType, + exp: std::time::Duration, + ) -> Result, Error> { + self.set_by_jti(claims.clone(), exp).await?; + Ok(Authenticated { + claims: Arc::new(claims), + jwt_encoding_key: self.0.jwt_encoding_key(), + algorithm: self.algorithm(), + }) + } } struct Extractor; @@ -99,7 +145,7 @@ impl Extractor { jwt_encoding_key: Arc, jwt_decoding_key: Arc, algorithm: Algorithm, - storage: Arc>, + storage: SessionStorage, ) -> Result<(), Error> { let Some(authorisation_header) = req .headers() @@ -111,13 +157,17 @@ impl Extractor { .to_str() .map_err(|_| Error::NoAuthHeader)?; - let decoded_claims = decode::(as_str, &*jwt_decoding_key, &Validation::new(algorithm)) - .map_err(|_e| { - // let error_message = e.to_string(); - Error::InvalidSession - })?; + let decoded_claims = + decode::(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err( + |_e| { + // let error_message = e.to_string(); + Error::InvalidSession + }, + )?; let stored = storage + .0 + .clone() .get_from_jti(decoded_claims.claims.jti()) .await .map_err(|_| Error::InvalidSession)?; @@ -126,10 +176,11 @@ impl Extractor { return Err(Error::InvalidSession); } - req.extensions_mut() - .insert(Authenticated { - claims: Arc::new(decoded_claims.claims), - }); + req.extensions_mut().insert(Authenticated { + claims: Arc::new(decoded_claims.claims), + jwt_encoding_key, + algorithm, + }); Ok(()) } } diff --git a/crates/actix-jwt-session/src/redis_adapter.rs b/crates/actix-jwt-session/src/redis_adapter.rs index 4e2e40e..249f504 100644 --- a/crates/actix-jwt-session/src/redis_adapter.rs +++ b/crates/actix-jwt-session/src/redis_adapter.rs @@ -1,7 +1,6 @@ use super::*; use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; use futures_util::future::LocalBoxFuture; -use jsonwebtoken::{DecodingKey, Validation}; use redis::AsyncCommands; use std::future::{ready, Ready}; use std::marker::PhantomData; @@ -9,15 +8,23 @@ use std::rc::Rc; use std::sync::Arc; #[derive(Clone)] -pub struct RedisStorage { +struct RedisStorage { pool: redis_async_pool::RedisPool, + jwt_encoding_key: Arc, + algorithm: Algorithm, _claims_type_marker: PhantomData, } impl RedisStorage { - pub fn new(pool: redis_async_pool::RedisPool) -> Self { + pub fn new( + pool: redis_async_pool::RedisPool, + jwt_encoding_key: Arc, + algorithm: Algorithm, + ) -> Self { Self { pool, + jwt_encoding_key, + algorithm, _claims_type_marker: Default::default(), } } @@ -53,6 +60,14 @@ where .map_err(|_| Error::WriteFailed)?; Ok(()) } + + fn jwt_encoding_key(&self) -> Arc { + self.jwt_encoding_key.clone() + } + + fn algorithm(&self) -> Algorithm { + self.algorithm + } } pub struct RedisMiddleware @@ -61,10 +76,10 @@ where { _claims_type_marker: std::marker::PhantomData, service: Rc, - jwt_encoding_key: Arc, - jwt_decoding_key: Arc, - algorithm: Algorithm, - storage: Arc>, + jwt_encoding_key: Arc, + jwt_decoding_key: Arc, + algorithm: Algorithm, + storage: SessionStorage, } impl Service for RedisMiddleware @@ -88,7 +103,14 @@ where let storage = self.storage.clone(); async move { - Extractor::extract_bearer_jwt(&req, jwt_encoding_key, jwt_decoding_key, algorithm, storage).await?; + Extractor::extract_bearer_jwt( + &req, + jwt_encoding_key, + jwt_decoding_key, + algorithm, + storage, + ) + .await?; let res = svc.call(req).await?; Ok(res) } @@ -98,10 +120,10 @@ where #[derive(Clone)] pub struct RedisMiddlewareFactory { - jwt_encoding_key: Arc, - jwt_decoding_key: Arc, - algorithm: Algorithm, - storage: Arc>, + jwt_encoding_key: Arc, + jwt_decoding_key: Arc, + algorithm: Algorithm, + storage: SessionStorage, _claims_type_marker: PhantomData, } @@ -112,15 +134,20 @@ impl RedisMiddlewareFactory { algorithm: Algorithm, pool: redis_async_pool::RedisPool, ) -> Self { + let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm)); + Self { jwt_encoding_key, jwt_decoding_key, algorithm, - storage: RedisStorage::new( - pool), + storage: SessionStorage(storage), _claims_type_marker: Default::default(), } } + + pub fn storage(&self) -> SessionStorage { + self.storage.clone() + } } impl Transform for RedisMiddlewareFactory @@ -138,8 +165,9 @@ where ready(Ok(RedisMiddleware { service: Rc::new(service), storage: self.storage.clone(), + jwt_encoding_key: self.jwt_encoding_key.clone(), jwt_decoding_key: self.jwt_decoding_key.clone(), - jwt_validator: self.jwt_validator.clone(), + algorithm: self.algorithm, _claims_type_marker: PhantomData, })) } diff --git a/crates/actix-jwt-session/tests/ensure_redis_flow.rs b/crates/actix-jwt-session/tests/ensure_redis_flow.rs index a78cf62..36803e5 100644 --- a/crates/actix-jwt-session/tests/ensure_redis_flow.rs +++ b/crates/actix-jwt-session/tests/ensure_redis_flow.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage}; -use actix_web::get; +use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage, TokenStorage}; use actix_web::http::StatusCode; -use actix_web::web::Data; +use actix_web::web::{Data, Json}; use actix_web::HttpResponse; +use actix_web::{get, post}; use actix_web::{http::header::ContentType, test, App}; use jsonwebtoken::*; use ring::rand::SystemRandom; @@ -25,8 +25,6 @@ impl actix_jwt_session::Claims for Claims { #[tokio::test(flavor = "multi_thread")] async fn not_authenticated() { - const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; - let validator = Validation::new(JWT_SIGNING_ALGO); let redis = { use redis_async_pool::{RedisConnectionManager, RedisPool}; RedisPool::new( @@ -41,12 +39,14 @@ async fn not_authenticated() { let keys = JwtSigningKeys::generate().unwrap(); let factory = RedisMiddlewareFactory::::new( + Arc::new(keys.encoding_key), Arc::new(keys.decoding_key), - Arc::new(validator), - Arc::new(RedisStorage::new(redis.clone())), + Algorithm::EdDSA, + redis.clone(), ); let app = App::new() + .app_data(factory.storage()) .wrap(factory.clone()) .app_data(Data::new(redis.clone())) .service(sign_in) @@ -56,26 +56,56 @@ async fn not_authenticated() { let app = actix_web::test::init_service(app).await; - let res = test::call_service(&app, test::TestRequest::default() - .insert_header(ContentType::plaintext()) - .to_request()).await; + let res = test::call_service( + &app, + test::TestRequest::default() + .insert_header(ContentType::plaintext()) + .to_request(), + ) + .await; assert!(res.status().is_success()); - let res = test::call_service(&app, test::TestRequest::default() - .uri("/s") - .insert_header(ContentType::plaintext()) - .to_request()).await; - let s = StatusCode::UNAUTHORIZED; - assert_eq!(res.status(), s); + let res = test::call_service( + &app, + test::TestRequest::default() + .uri("/s") + .insert_header(ContentType::plaintext()) + .to_request(), + ) + .await; + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let origina_claims = Claims { id: Uuid::new_v4() }; + let res = test::call_service( + &app, + test::TestRequest::default() + .uri("/in") + .method(actix_web::http::Method::POST) + .insert_header(ContentType::json()) + .set_json(&origina_claims) + .to_request(), + ) + .await; + assert_eq!(res.status(), StatusCode::OK); } -#[get("/in")] -async fn sign_in(store: Data>) -> HttpResponse { - HttpResponse::Ok().body("") +#[post("/in")] +async fn sign_in( + store: Data>, + claims: Json, +) -> Result { + let claims = claims.into_inner(); + let store = store.into_inner(); + store + .clone() + .set_by_jti(claims, std::time::Duration::from_secs(300)) + .await + .unwrap(); + Ok(HttpResponse::Ok().body("")) } -#[get("/out")] -async fn sign_out(store: Data>) -> HttpResponse { +#[post("/out")] +async fn sign_out(_store: Data>) -> HttpResponse { HttpResponse::Ok().body("") } diff --git a/crates/oswilno-server/src/main.rs b/crates/oswilno-server/src/main.rs index 2a6a394..250da1d 100644 --- a/crates/oswilno-server/src/main.rs +++ b/crates/oswilno-server/src/main.rs @@ -48,9 +48,11 @@ async fn main() -> std::io::Result<()> { HttpServer::new(move || { let session_config = session_config.clone(); + let session_factory = session_config.factory(); App::new() .wrap(middleware::Logger::default()) - .wrap(session_config.factory()) + .app_data(Data::new(session_factory.storage())) + .wrap(session_factory) .app_data(Data::new(conn.clone())) .app_data(Data::new(redis.clone())) .app_data(Data::new(l10n.clone())) diff --git a/crates/oswilno-session/src/lib.rs b/crates/oswilno-session/src/lib.rs index b09d4ba..6053a28 100644 --- a/crates/oswilno-session/src/lib.rs +++ b/crates/oswilno-session/src/lib.rs @@ -1,6 +1,7 @@ use std::ops::Add; use std::sync::Arc; +use actix_jwt_session::SessionStorage; pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory}; use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::{get, post, HttpResponse}; @@ -13,7 +14,6 @@ use ring::rand::SystemRandom; use ring::signature::{Ed25519KeyPair, KeyPair}; use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; -use time::ext::*; use time::OffsetDateTime; mod extract_session; @@ -24,7 +24,7 @@ pub use oswilno_view::filters; pub type UserSession = Claims; #[derive(Clone, Copy)] -pub struct JWTTtl(time::Duration); +pub struct JWTTtl(std::time::Duration); #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)] #[serde(rename_all = "snake_case")] @@ -73,19 +73,15 @@ pub struct LoginResponse { claims: Claims, } -const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; - #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, - encoding_key: Data, factory: RedisMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config - .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) .service(login_view) @@ -130,17 +126,16 @@ impl SessionConfigurator { } pub fn new(redis: redis_async_pool::RedisPool) -> Self { - let jwt_ttl = JWTTtl(31.days()); + let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60)); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); - let validator = Validation::new(JWT_SIGNING_ALGO); let auth_middleware_factory = RedisMiddlewareFactory::::new( + Arc::new(jwt_signing_keys.encoding_key), Arc::new(jwt_signing_keys.decoding_key), - Arc::new(validator), - Arc::new(actix_jwt_session::RedisStorage::new(redis)), + Algorithm::EdDSA, + redis, ); Self { - encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, } @@ -186,18 +181,16 @@ async fn login_partial_view(t: Data) -> SignInPartialTemplat #[autometrics] #[post("/login")] async fn login( - jwt_encoding_key: Data, jwt_ttl: Data, db: Data, + redis: Data>, payload: Form, t: Data, lang: Lang, - redis: Data, ) -> Result { let t = t.into_inner(); let mut errors = Errors::default(); match login_inner( - jwt_encoding_key, jwt_ttl.into_inner(), payload.into_inner(), db.into_inner(), @@ -221,11 +214,10 @@ async fn login( } async fn login_inner( - jwt_encoding_key: Data, jwt_ttl: Arc, payload: SignInPayload, db: Arc, - redis: Arc, + redis: Arc>, errors: &mut Errors, ) -> Result { let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; @@ -261,42 +253,24 @@ async fn login_inner( audience: Audience::Web, jwt_id: uuid::Uuid::new_v4(), }; - let jwt_token = encode( - &Header::new(JWT_SIGNING_ALGO), - &jwt_claims, - &jwt_encoding_key, - ) - .map_err(|_| { - errors.push_global("Bad credentials"); - payload.clone() - })?; - let Ok(bin_value) = bincode::serialize(&jwt_claims) else { - errors.push_global("Failed to sign in. Please try later"); - return Err(payload.clone()); - }; - { - use redis::AsyncCommands; - - let Ok(mut conn) = redis.get().await else { - errors.push_global("Failed to sign in. Please try later"); - return Err(payload); - }; - - if let Err(e) = conn - .set_ex::<'_, _, _, String>( - jwt_claims.jwt_id.as_bytes(), - bin_value, - jwt_ttl.0.as_seconds_f32() as usize, - ) - .await - { + let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await { + Err(e) => { tracing::warn!("Failed to set sign-in claims in redis: {e}"); errors.push_global("Failed to sign in. Please try later"); return Err(payload); } - } + Ok(jwt_token) => jwt_token, + }; + let bearer_token = match jwt_token.encode() { + Ok(token) => token, + Err(e) => { + tracing::warn!("Failed to set sign-in claims in redis: {e}"); + errors.push_global("Failed to sign in. Please try later"); + return Err(payload); + } + }; Ok(LoginResponse { - bearer_token: jwt_token, + bearer_token, claims: jwt_claims, }) } diff --git a/crates/oswilno-view/src/lib.rs b/crates/oswilno-view/src/lib.rs index 988df62..5e00845 100644 --- a/crates/oswilno-view/src/lib.rs +++ b/crates/oswilno-view/src/lib.rs @@ -8,7 +8,7 @@ pub mod lang; #[derive(Debug, askama_actix::Template)] #[template(path = "../templates/base.html")] -pub struct Layout { +pub struct Layout { pub main: BodyTemplate, } diff --git a/crates/oswilno-view/templates/base.html b/crates/oswilno-view/templates/base.html index bfbbfd3..f8c1718 100644 --- a/crates/oswilno-view/templates/base.html +++ b/crates/oswilno-view/templates/base.html @@ -11,6 +11,8 @@ -
{{ main }}
+
+ {{ main|safe }} +