Add refresh toke

This commit is contained in:
Adrian Woźniak 2023-08-24 16:26:10 +02:00
parent 877dcae1a9
commit 58a0239a05
5 changed files with 69 additions and 34 deletions

View File

@ -127,14 +127,18 @@ use async_trait::async_trait;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation}; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::Deserialize; use serde::Deserialize;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::borrow::Cow;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use uuid::Uuid; use uuid::Uuid;
use std::borrow::Cow;
/// Default authorization header is "Authorization" /// Default authorization header is "Authorization"
pub static DEFAULT_HEADER_NAME: &str = "Authorization"; pub static JWT_HEADER_NAME: &str = "Authorization";
pub static REFRESH_HEADER_NAME: &str = "X-Refresh";
pub static JWT_COOKIE_NAME: &str = "ACX-Auth";
pub static REFRESH_COOKIE_NAME: &str = "ACX-Refresh";
pub static REFRESH_PARAM_NAME: &str = "refresh_token";
/// Serializable and storable struct which represent JWT claims /// Serializable and storable struct which represent JWT claims
/// ///
@ -666,7 +670,9 @@ impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
#[async_trait(?Send)] #[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> { impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> { async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into) req.cookie(self.cookie_name)
.map(|c| c.value().to_string())
.map(Into::into)
} }
} }
@ -693,9 +699,10 @@ impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
#[async_trait(?Send)] #[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> { impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> { async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
req req.headers()
.headers() .get(self.header_name)
.get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into) .and_then(|h| h.to_str().ok())
.map(Into::into)
} }
} }

View File

@ -153,17 +153,20 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
algorithm: Algorithm, algorithm: Algorithm,
pool: redis_async_pool::RedisPool, pool: redis_async_pool::RedisPool,
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>, extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
) -> Self { ) -> (SessionStorage<ClaimsType>, Self) {
let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool)); let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
let storage = SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm);
Self { (
jwt_encoding_key: jwt_encoding_key.clone(), storage.clone(),
jwt_decoding_key, Self {
algorithm, jwt_encoding_key: jwt_encoding_key.clone(),
storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm), jwt_decoding_key,
extractors: Arc::new(extractors), algorithm,
_claims_type_marker: Default::default(), storage,
} extractors: Arc::new(extractors),
_claims_type_marker: Default::default(),
},
)
} }
pub fn storage(&self) -> SessionStorage<ClaimsType> { pub fn storage(&self) -> SessionStorage<ClaimsType> {

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use actix_jwt_session::{ use actix_jwt_session::{
Authenticated, HeaderExtractor, RedisMiddlewareFactory, SessionStorage, DEFAULT_HEADER_NAME, Authenticated, HeaderExtractor, RedisMiddlewareFactory, SessionStorage, JWT_HEADER_NAME,
}; };
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
use actix_web::web::{Data, Json}; use actix_web::web::{Data, Json};
@ -49,7 +49,7 @@ async fn not_authenticated() {
Arc::new(keys.decoding_key), Arc::new(keys.decoding_key),
Algorithm::EdDSA, Algorithm::EdDSA,
redis.clone(), redis.clone(),
vec![Box::new(HeaderExtractor::new(DEFAULT_HEADER_NAME))], vec![Box::new(HeaderExtractor::new(JWT_HEADER_NAME))],
); );
let app = App::new() let app = App::new()

View File

@ -22,7 +22,7 @@ async fn main() -> std::io::Result<()> {
.init(); .init();
} }
let conn: sea_orm::DatabaseConnection = { let pq: sea_orm::DatabaseConnection = {
let mut db_opts = ConnectOptions::new("postgres://postgres@localhost/oswilno".to_string()); let mut db_opts = ConnectOptions::new("postgres://postgres@localhost/oswilno".to_string());
db_opts db_opts
.max_connections(100) .max_connections(100)
@ -51,17 +51,16 @@ async fn main() -> std::io::Result<()> {
oswilno_parking_space::translations(&mut l10n); oswilno_parking_space::translations(&mut l10n);
oswilno_parking_space::init(Arc::new(conn.clone())).await; oswilno_parking_space::init(Arc::new(pq.clone())).await;
HttpServer::new(move || { HttpServer::new(move || {
let session_config = session_config.clone(); let session_config = session_config.clone();
let session_factory = session_config.factory(); let session_factory = session_config.factory();
App::new() App::new()
.wrap(middleware::Logger::default()) .wrap(middleware::Logger::default())
.app_data(Data::new(session_factory.storage()))
.wrap(session_factory) .wrap(session_factory)
.app_data(Data::new(conn.clone())) .app_data(Data::new(pq.clone()))
.app_data(Data::new(redis.clone())) // .app_data(Data::new(redis.clone()))
.app_data(Data::new(l10n.clone())) .app_data(Data::new(l10n.clone()))
.configure(oswilno_parking_space::mount) .configure(oswilno_parking_space::mount)
.configure(oswilno_admin::mount) .configure(oswilno_admin::mount)

View File

@ -2,7 +2,7 @@ use std::io::Read;
use std::ops::Add; use std::ops::Add;
use std::sync::Arc; use std::sync::Arc;
use actix_jwt_session::{CookieExtractor, HeaderExtractor, SessionStorage, DEFAULT_HEADER_NAME}; use actix_jwt_session::{CookieExtractor, HeaderExtractor, SessionStorage, JWT_HEADER_NAME};
pub use actix_jwt_session::{Error, RedisMiddlewareFactory}; pub use actix_jwt_session::{Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpRequest, HttpResponse}; use actix_web::{get, post, HttpRequest, HttpResponse};
@ -27,6 +27,9 @@ pub type MaybeAuthenticated = actix_jwt_session::MaybeAuthenticated<Claims>;
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct JWTTtl(std::time::Duration); pub struct JWTTtl(std::time::Duration);
#[derive(Clone, Copy)]
pub struct RefreshTtl(std::time::Duration);
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)] #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Audience { pub enum Audience {
@ -79,13 +82,17 @@ pub struct LoginResponse {
#[derive(Clone)] #[derive(Clone)]
pub struct SessionConfigurator { pub struct SessionConfigurator {
jwt_ttl: Data<JWTTtl>, jwt_ttl: Data<JWTTtl>,
refresh_ttl: Data<RefreshTtl>,
factory: RedisMiddlewareFactory<Claims>, factory: RedisMiddlewareFactory<Claims>,
session_storage: SessionStorage<Claims>,
} }
impl SessionConfigurator { impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) { pub fn app_data(self, config: &mut ServiceConfig) {
config config
.app_data(self.session_storage)
.app_data(self.jwt_ttl) .app_data(self.jwt_ttl)
.app_data(self.refresh_ttl)
.service(login) .service(login)
.service(login_view) .service(login_view)
.service(logout) .service(logout)
@ -122,24 +129,27 @@ impl SessionConfigurator {
pub fn new(redis: redis_async_pool::RedisPool) -> Self { pub fn new(redis: redis_async_pool::RedisPool) -> Self {
let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60)); let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60));
let refresh_ttl = RefreshTtl(std::time::Duration::from_secs(6 * 31 * 60 * 60));
std::fs::create_dir_all("./config").ok(); std::fs::create_dir_all("./config").ok();
let jwt_signing_keys = JwtSigningKeys::load_or_create(); let jwt_signing_keys = JwtSigningKeys::load_or_create();
let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new( let (session_storage, auth_middleware_factory) = RedisMiddlewareFactory::<Claims>::new(
Arc::new(jwt_signing_keys.encoding_key), Arc::new(jwt_signing_keys.encoding_key),
Arc::new(jwt_signing_keys.decoding_key), Arc::new(jwt_signing_keys.decoding_key),
Algorithm::EdDSA, Algorithm::EdDSA,
redis, redis,
vec![ vec![
Box::new(CookieExtractor::<Claims>::new(DEFAULT_HEADER_NAME)), Box::new(CookieExtractor::<Claims>::new(JWT_HEADER_NAME)),
Box::new(HeaderExtractor::<Claims>::new(DEFAULT_HEADER_NAME)), Box::new(HeaderExtractor::<Claims>::new(JWT_HEADER_NAME)),
], ],
); );
Self { Self {
jwt_ttl: Data::new(jwt_ttl.clone()), jwt_ttl: Data::new(jwt_ttl),
refresh_ttl: Data::new(refresh_ttl),
factory: auth_middleware_factory, factory: auth_middleware_factory,
session_storage,
} }
} }
} }
@ -204,6 +214,7 @@ async fn login_view(req: HttpRequest, t: Data<TranslationStorage>) -> HttpRespon
#[post("/login")] #[post("/login")]
async fn login( async fn login(
jwt_ttl: Data<JWTTtl>, jwt_ttl: Data<JWTTtl>,
refresh_ttl: Data<RefreshTtl>,
db: Data<DatabaseConnection>, db: Data<DatabaseConnection>,
redis: Data<SessionStorage<Claims>>, redis: Data<SessionStorage<Claims>>,
payload: Form<SignInPayload>, payload: Form<SignInPayload>,
@ -214,6 +225,7 @@ async fn login(
let mut errors = Errors::default(); let mut errors = Errors::default();
match login_inner( match login_inner(
jwt_ttl.into_inner(), jwt_ttl.into_inner(),
refresh_ttl.into_inner(),
payload.into_inner(), payload.into_inner(),
db.into_inner(), db.into_inner(),
redis.into_inner(), redis.into_inner(),
@ -237,6 +249,7 @@ async fn login(
async fn login_inner( async fn login_inner(
jwt_ttl: Arc<JWTTtl>, jwt_ttl: Arc<JWTTtl>,
refresh_ttl: Arc<RefreshTtl>,
payload: SignInPayload, payload: SignInPayload,
db: Arc<DatabaseConnection>, db: Arc<DatabaseConnection>,
redis: Arc<SessionStorage<Claims>>, redis: Arc<SessionStorage<Claims>>,
@ -280,7 +293,7 @@ async fn login_inner(
jwt_id: uuid::Uuid::new_v4(), jwt_id: uuid::Uuid::new_v4(),
account_id: account.id, account_id: account.id,
}; };
let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await { let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0, refresh_ttl.0).await {
Err(e) => { Err(e) => {
tracing::warn!("Failed to set sign-in claims in redis: {e}"); tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later"); errors.push_global("Failed to sign in. Please try later");
@ -288,7 +301,15 @@ async fn login_inner(
} }
Ok(jwt_token) => jwt_token, Ok(jwt_token) => jwt_token,
}; };
let bearer_token = match jwt_token.encode() { let bearer_token = match jwt_token.jwt.encode() {
Ok(token) => token,
Err(e) => {
tracing::warn!("Failed to encode claims: {e}");
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
}
};
let refresh_token = match jwt_token.refresh.encode() {
Ok(token) => token, Ok(token) => token,
Err(e) => { Err(e) => {
tracing::warn!("Failed to encode claims: {e}"); tracing::warn!("Failed to encode claims: {e}");
@ -297,18 +318,23 @@ async fn login_inner(
} }
}; };
let cookie = let jwt_cookie =
actix_web::cookie::Cookie::build(actix_jwt_session::DEFAULT_HEADER_NAME, &bearer_token) actix_web::cookie::Cookie::build(actix_jwt_session::JWT_COOKIE_NAME, &bearer_token)
.http_only(true)
.finish();
let refresh_cookie =
actix_web::cookie::Cookie::build(actix_jwt_session::REFRESH_COOKIE_NAME, &refresh_token)
.http_only(true) .http_only(true)
.finish(); .finish();
Ok(HttpResponse::SeeOther() Ok(HttpResponse::SeeOther()
.append_header(( .append_header((
actix_jwt_session::DEFAULT_HEADER_NAME, actix_jwt_session::JWT_HEADER_NAME,
format!("Bearer {bearer_token}").as_str(), format!("Bearer {bearer_token}").as_str(),
)) ))
.append_header(("Location", "/")) .append_header(("Location", "/"))
.append_header(("HX-Redirect", "/")) .append_header(("HX-Redirect", "/"))
.cookie(cookie) .cookie(jwt_cookie)
.cookie(refresh_cookie)
.body("")) .body(""))
} }