From 58a0239a05ccbb86e54eed2ca56d15adc69c9cbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20Wo=C5=BAniak?= Date: Thu, 24 Aug 2023 16:26:10 +0200 Subject: [PATCH] Add refresh toke --- crates/actix-jwt-session/src/lib.rs | 19 +++++--- crates/actix-jwt-session/src/redis_adapter.rs | 23 +++++---- .../tests/ensure_redis_flow.rs | 4 +- crates/oswilno-server/src/main.rs | 9 ++-- crates/oswilno-session/src/lib.rs | 48 ++++++++++++++----- 5 files changed, 69 insertions(+), 34 deletions(-) diff --git a/crates/actix-jwt-session/src/lib.rs b/crates/actix-jwt-session/src/lib.rs index b2c43c6..e5147cb 100644 --- a/crates/actix-jwt-session/src/lib.rs +++ b/crates/actix-jwt-session/src/lib.rs @@ -127,14 +127,18 @@ use async_trait::async_trait; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation}; use serde::Deserialize; use serde::{de::DeserializeOwned, Serialize}; +use std::borrow::Cow; use std::marker::PhantomData; use std::sync::Arc; use std::time::SystemTime; use uuid::Uuid; -use std::borrow::Cow; /// Default authorization header is "Authorization" -pub static DEFAULT_HEADER_NAME: &str = "Authorization"; +pub static JWT_HEADER_NAME: &str = "Authorization"; +pub static REFRESH_HEADER_NAME: &str = "X-Refresh"; +pub static JWT_COOKIE_NAME: &str = "ACX-Auth"; +pub static REFRESH_COOKIE_NAME: &str = "ACX-Refresh"; +pub static REFRESH_PARAM_NAME: &str = "refresh_token"; /// Serializable and storable struct which represent JWT claims /// @@ -666,7 +670,9 @@ impl CookieExtractor { #[async_trait(?Send)] impl SessionExtractor for CookieExtractor { async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option> { - req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into) + req.cookie(self.cookie_name) + .map(|c| c.value().to_string()) + .map(Into::into) } } @@ -693,9 +699,10 @@ impl HeaderExtractor { #[async_trait(?Send)] impl SessionExtractor for HeaderExtractor { async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option> { - req - .headers() - .get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into) + req.headers() + .get(self.header_name) + .and_then(|h| h.to_str().ok()) + .map(Into::into) } } diff --git a/crates/actix-jwt-session/src/redis_adapter.rs b/crates/actix-jwt-session/src/redis_adapter.rs index 53f5197..b45467d 100644 --- a/crates/actix-jwt-session/src/redis_adapter.rs +++ b/crates/actix-jwt-session/src/redis_adapter.rs @@ -153,17 +153,20 @@ impl RedisMiddlewareFactory { algorithm: Algorithm, pool: redis_async_pool::RedisPool, extractors: Vec>>, - ) -> Self { + ) -> (SessionStorage, Self) { let storage = Arc::new(RedisStorage::::new(pool)); - - Self { - jwt_encoding_key: jwt_encoding_key.clone(), - jwt_decoding_key, - algorithm, - storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm), - extractors: Arc::new(extractors), - _claims_type_marker: Default::default(), - } + let storage = SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm); + ( + storage.clone(), + Self { + jwt_encoding_key: jwt_encoding_key.clone(), + jwt_decoding_key, + algorithm, + storage, + extractors: Arc::new(extractors), + _claims_type_marker: Default::default(), + }, + ) } pub fn storage(&self) -> SessionStorage { diff --git a/crates/actix-jwt-session/tests/ensure_redis_flow.rs b/crates/actix-jwt-session/tests/ensure_redis_flow.rs index e8734c8..03f392e 100644 --- a/crates/actix-jwt-session/tests/ensure_redis_flow.rs +++ b/crates/actix-jwt-session/tests/ensure_redis_flow.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use actix_jwt_session::{ - Authenticated, HeaderExtractor, RedisMiddlewareFactory, SessionStorage, DEFAULT_HEADER_NAME, + Authenticated, HeaderExtractor, RedisMiddlewareFactory, SessionStorage, JWT_HEADER_NAME, }; use actix_web::http::StatusCode; use actix_web::web::{Data, Json}; @@ -49,7 +49,7 @@ async fn not_authenticated() { Arc::new(keys.decoding_key), Algorithm::EdDSA, redis.clone(), - vec![Box::new(HeaderExtractor::new(DEFAULT_HEADER_NAME))], + vec![Box::new(HeaderExtractor::new(JWT_HEADER_NAME))], ); let app = App::new() diff --git a/crates/oswilno-server/src/main.rs b/crates/oswilno-server/src/main.rs index 5ea778b..c3f80e9 100644 --- a/crates/oswilno-server/src/main.rs +++ b/crates/oswilno-server/src/main.rs @@ -22,7 +22,7 @@ async fn main() -> std::io::Result<()> { .init(); } - let conn: sea_orm::DatabaseConnection = { + let pq: sea_orm::DatabaseConnection = { let mut db_opts = ConnectOptions::new("postgres://postgres@localhost/oswilno".to_string()); db_opts .max_connections(100) @@ -51,17 +51,16 @@ async fn main() -> std::io::Result<()> { oswilno_parking_space::translations(&mut l10n); - oswilno_parking_space::init(Arc::new(conn.clone())).await; + oswilno_parking_space::init(Arc::new(pq.clone())).await; HttpServer::new(move || { let session_config = session_config.clone(); let session_factory = session_config.factory(); App::new() .wrap(middleware::Logger::default()) - .app_data(Data::new(session_factory.storage())) .wrap(session_factory) - .app_data(Data::new(conn.clone())) - .app_data(Data::new(redis.clone())) + .app_data(Data::new(pq.clone())) + // .app_data(Data::new(redis.clone())) .app_data(Data::new(l10n.clone())) .configure(oswilno_parking_space::mount) .configure(oswilno_admin::mount) diff --git a/crates/oswilno-session/src/lib.rs b/crates/oswilno-session/src/lib.rs index 34e2ab7..94fd68d 100644 --- a/crates/oswilno-session/src/lib.rs +++ b/crates/oswilno-session/src/lib.rs @@ -2,7 +2,7 @@ use std::io::Read; use std::ops::Add; use std::sync::Arc; -use actix_jwt_session::{CookieExtractor, HeaderExtractor, SessionStorage, DEFAULT_HEADER_NAME}; +use actix_jwt_session::{CookieExtractor, HeaderExtractor, SessionStorage, JWT_HEADER_NAME}; pub use actix_jwt_session::{Error, RedisMiddlewareFactory}; use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::{get, post, HttpRequest, HttpResponse}; @@ -27,6 +27,9 @@ pub type MaybeAuthenticated = actix_jwt_session::MaybeAuthenticated; #[derive(Clone, Copy)] pub struct JWTTtl(std::time::Duration); +#[derive(Clone, Copy)] +pub struct RefreshTtl(std::time::Duration); + #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)] #[serde(rename_all = "snake_case")] pub enum Audience { @@ -79,13 +82,17 @@ pub struct LoginResponse { #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, + refresh_ttl: Data, factory: RedisMiddlewareFactory, + session_storage: SessionStorage, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config + .app_data(self.session_storage) .app_data(self.jwt_ttl) + .app_data(self.refresh_ttl) .service(login) .service(login_view) .service(logout) @@ -122,24 +129,27 @@ impl SessionConfigurator { pub fn new(redis: redis_async_pool::RedisPool) -> Self { let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60)); + let refresh_ttl = RefreshTtl(std::time::Duration::from_secs(6 * 31 * 60 * 60)); std::fs::create_dir_all("./config").ok(); let jwt_signing_keys = JwtSigningKeys::load_or_create(); - let auth_middleware_factory = RedisMiddlewareFactory::::new( + let (session_storage, auth_middleware_factory) = RedisMiddlewareFactory::::new( Arc::new(jwt_signing_keys.encoding_key), Arc::new(jwt_signing_keys.decoding_key), Algorithm::EdDSA, redis, vec![ - Box::new(CookieExtractor::::new(DEFAULT_HEADER_NAME)), - Box::new(HeaderExtractor::::new(DEFAULT_HEADER_NAME)), + Box::new(CookieExtractor::::new(JWT_HEADER_NAME)), + Box::new(HeaderExtractor::::new(JWT_HEADER_NAME)), ], ); Self { - jwt_ttl: Data::new(jwt_ttl.clone()), + jwt_ttl: Data::new(jwt_ttl), + refresh_ttl: Data::new(refresh_ttl), factory: auth_middleware_factory, + session_storage, } } } @@ -204,6 +214,7 @@ async fn login_view(req: HttpRequest, t: Data) -> HttpRespon #[post("/login")] async fn login( jwt_ttl: Data, + refresh_ttl: Data, db: Data, redis: Data>, payload: Form, @@ -214,6 +225,7 @@ async fn login( let mut errors = Errors::default(); match login_inner( jwt_ttl.into_inner(), + refresh_ttl.into_inner(), payload.into_inner(), db.into_inner(), redis.into_inner(), @@ -237,6 +249,7 @@ async fn login( async fn login_inner( jwt_ttl: Arc, + refresh_ttl: Arc, payload: SignInPayload, db: Arc, redis: Arc>, @@ -280,7 +293,7 @@ async fn login_inner( jwt_id: uuid::Uuid::new_v4(), account_id: account.id, }; - let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await { + let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0, refresh_ttl.0).await { Err(e) => { tracing::warn!("Failed to set sign-in claims in redis: {e}"); errors.push_global("Failed to sign in. Please try later"); @@ -288,7 +301,15 @@ async fn login_inner( } Ok(jwt_token) => jwt_token, }; - let bearer_token = match jwt_token.encode() { + let bearer_token = match jwt_token.jwt.encode() { + Ok(token) => token, + Err(e) => { + tracing::warn!("Failed to encode claims: {e}"); + errors.push_global("Failed to sign in. Please try later"); + return Err(payload); + } + }; + let refresh_token = match jwt_token.refresh.encode() { Ok(token) => token, Err(e) => { tracing::warn!("Failed to encode claims: {e}"); @@ -297,18 +318,23 @@ async fn login_inner( } }; - let cookie = - actix_web::cookie::Cookie::build(actix_jwt_session::DEFAULT_HEADER_NAME, &bearer_token) + let jwt_cookie = + actix_web::cookie::Cookie::build(actix_jwt_session::JWT_COOKIE_NAME, &bearer_token) + .http_only(true) + .finish(); + let refresh_cookie = + actix_web::cookie::Cookie::build(actix_jwt_session::REFRESH_COOKIE_NAME, &refresh_token) .http_only(true) .finish(); Ok(HttpResponse::SeeOther() .append_header(( - actix_jwt_session::DEFAULT_HEADER_NAME, + actix_jwt_session::JWT_HEADER_NAME, format!("Bearer {bearer_token}").as_str(), )) .append_header(("Location", "/")) .append_header(("HX-Redirect", "/")) - .cookie(cookie) + .cookie(jwt_cookie) + .cookie(refresh_cookie) .body("")) }