Add refresh toke

This commit is contained in:
Adrian Woźniak 2023-08-24 16:26:10 +02:00
parent 877dcae1a9
commit 58a0239a05
5 changed files with 69 additions and 34 deletions

View File

@ -127,14 +127,18 @@ use async_trait::async_trait;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::Deserialize;
use serde::{de::DeserializeOwned, Serialize};
use std::borrow::Cow;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::SystemTime;
use uuid::Uuid;
use std::borrow::Cow;
/// Default authorization header is "Authorization"
pub static DEFAULT_HEADER_NAME: &str = "Authorization";
pub static JWT_HEADER_NAME: &str = "Authorization";
pub static REFRESH_HEADER_NAME: &str = "X-Refresh";
pub static JWT_COOKIE_NAME: &str = "ACX-Auth";
pub static REFRESH_COOKIE_NAME: &str = "ACX-Refresh";
pub static REFRESH_PARAM_NAME: &str = "refresh_token";
/// Serializable and storable struct which represent JWT claims
///
@ -666,7 +670,9 @@ impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into)
req.cookie(self.cookie_name)
.map(|c| c.value().to_string())
.map(Into::into)
}
}
@ -693,9 +699,10 @@ impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
req
.headers()
.get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into)
req.headers()
.get(self.header_name)
.and_then(|h| h.to_str().ok())
.map(Into::into)
}
}

View File

@ -153,17 +153,20 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
algorithm: Algorithm,
pool: redis_async_pool::RedisPool,
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
) -> Self {
) -> (SessionStorage<ClaimsType>, Self) {
let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
Self {
jwt_encoding_key: jwt_encoding_key.clone(),
jwt_decoding_key,
algorithm,
storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm),
extractors: Arc::new(extractors),
_claims_type_marker: Default::default(),
}
let storage = SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm);
(
storage.clone(),
Self {
jwt_encoding_key: jwt_encoding_key.clone(),
jwt_decoding_key,
algorithm,
storage,
extractors: Arc::new(extractors),
_claims_type_marker: Default::default(),
},
)
}
pub fn storage(&self) -> SessionStorage<ClaimsType> {

View File

@ -1,7 +1,7 @@
use std::sync::Arc;
use actix_jwt_session::{
Authenticated, HeaderExtractor, RedisMiddlewareFactory, SessionStorage, DEFAULT_HEADER_NAME,
Authenticated, HeaderExtractor, RedisMiddlewareFactory, SessionStorage, JWT_HEADER_NAME,
};
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
@ -49,7 +49,7 @@ async fn not_authenticated() {
Arc::new(keys.decoding_key),
Algorithm::EdDSA,
redis.clone(),
vec![Box::new(HeaderExtractor::new(DEFAULT_HEADER_NAME))],
vec![Box::new(HeaderExtractor::new(JWT_HEADER_NAME))],
);
let app = App::new()

View File

@ -22,7 +22,7 @@ async fn main() -> std::io::Result<()> {
.init();
}
let conn: sea_orm::DatabaseConnection = {
let pq: sea_orm::DatabaseConnection = {
let mut db_opts = ConnectOptions::new("postgres://postgres@localhost/oswilno".to_string());
db_opts
.max_connections(100)
@ -51,17 +51,16 @@ async fn main() -> std::io::Result<()> {
oswilno_parking_space::translations(&mut l10n);
oswilno_parking_space::init(Arc::new(conn.clone())).await;
oswilno_parking_space::init(Arc::new(pq.clone())).await;
HttpServer::new(move || {
let session_config = session_config.clone();
let session_factory = session_config.factory();
App::new()
.wrap(middleware::Logger::default())
.app_data(Data::new(session_factory.storage()))
.wrap(session_factory)
.app_data(Data::new(conn.clone()))
.app_data(Data::new(redis.clone()))
.app_data(Data::new(pq.clone()))
// .app_data(Data::new(redis.clone()))
.app_data(Data::new(l10n.clone()))
.configure(oswilno_parking_space::mount)
.configure(oswilno_admin::mount)

View File

@ -2,7 +2,7 @@ use std::io::Read;
use std::ops::Add;
use std::sync::Arc;
use actix_jwt_session::{CookieExtractor, HeaderExtractor, SessionStorage, DEFAULT_HEADER_NAME};
use actix_jwt_session::{CookieExtractor, HeaderExtractor, SessionStorage, JWT_HEADER_NAME};
pub use actix_jwt_session::{Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpRequest, HttpResponse};
@ -27,6 +27,9 @@ pub type MaybeAuthenticated = actix_jwt_session::MaybeAuthenticated<Claims>;
#[derive(Clone, Copy)]
pub struct JWTTtl(std::time::Duration);
#[derive(Clone, Copy)]
pub struct RefreshTtl(std::time::Duration);
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum Audience {
@ -79,13 +82,17 @@ pub struct LoginResponse {
#[derive(Clone)]
pub struct SessionConfigurator {
jwt_ttl: Data<JWTTtl>,
refresh_ttl: Data<RefreshTtl>,
factory: RedisMiddlewareFactory<Claims>,
session_storage: SessionStorage<Claims>,
}
impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) {
config
.app_data(self.session_storage)
.app_data(self.jwt_ttl)
.app_data(self.refresh_ttl)
.service(login)
.service(login_view)
.service(logout)
@ -122,24 +129,27 @@ impl SessionConfigurator {
pub fn new(redis: redis_async_pool::RedisPool) -> Self {
let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60));
let refresh_ttl = RefreshTtl(std::time::Duration::from_secs(6 * 31 * 60 * 60));
std::fs::create_dir_all("./config").ok();
let jwt_signing_keys = JwtSigningKeys::load_or_create();
let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new(
let (session_storage, auth_middleware_factory) = RedisMiddlewareFactory::<Claims>::new(
Arc::new(jwt_signing_keys.encoding_key),
Arc::new(jwt_signing_keys.decoding_key),
Algorithm::EdDSA,
redis,
vec![
Box::new(CookieExtractor::<Claims>::new(DEFAULT_HEADER_NAME)),
Box::new(HeaderExtractor::<Claims>::new(DEFAULT_HEADER_NAME)),
Box::new(CookieExtractor::<Claims>::new(JWT_HEADER_NAME)),
Box::new(HeaderExtractor::<Claims>::new(JWT_HEADER_NAME)),
],
);
Self {
jwt_ttl: Data::new(jwt_ttl.clone()),
jwt_ttl: Data::new(jwt_ttl),
refresh_ttl: Data::new(refresh_ttl),
factory: auth_middleware_factory,
session_storage,
}
}
}
@ -204,6 +214,7 @@ async fn login_view(req: HttpRequest, t: Data<TranslationStorage>) -> HttpRespon
#[post("/login")]
async fn login(
jwt_ttl: Data<JWTTtl>,
refresh_ttl: Data<RefreshTtl>,
db: Data<DatabaseConnection>,
redis: Data<SessionStorage<Claims>>,
payload: Form<SignInPayload>,
@ -214,6 +225,7 @@ async fn login(
let mut errors = Errors::default();
match login_inner(
jwt_ttl.into_inner(),
refresh_ttl.into_inner(),
payload.into_inner(),
db.into_inner(),
redis.into_inner(),
@ -237,6 +249,7 @@ async fn login(
async fn login_inner(
jwt_ttl: Arc<JWTTtl>,
refresh_ttl: Arc<RefreshTtl>,
payload: SignInPayload,
db: Arc<DatabaseConnection>,
redis: Arc<SessionStorage<Claims>>,
@ -280,7 +293,7 @@ async fn login_inner(
jwt_id: uuid::Uuid::new_v4(),
account_id: account.id,
};
let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await {
let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0, refresh_ttl.0).await {
Err(e) => {
tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later");
@ -288,7 +301,15 @@ async fn login_inner(
}
Ok(jwt_token) => jwt_token,
};
let bearer_token = match jwt_token.encode() {
let bearer_token = match jwt_token.jwt.encode() {
Ok(token) => token,
Err(e) => {
tracing::warn!("Failed to encode claims: {e}");
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
}
};
let refresh_token = match jwt_token.refresh.encode() {
Ok(token) => token,
Err(e) => {
tracing::warn!("Failed to encode claims: {e}");
@ -297,18 +318,23 @@ async fn login_inner(
}
};
let cookie =
actix_web::cookie::Cookie::build(actix_jwt_session::DEFAULT_HEADER_NAME, &bearer_token)
let jwt_cookie =
actix_web::cookie::Cookie::build(actix_jwt_session::JWT_COOKIE_NAME, &bearer_token)
.http_only(true)
.finish();
let refresh_cookie =
actix_web::cookie::Cookie::build(actix_jwt_session::REFRESH_COOKIE_NAME, &refresh_token)
.http_only(true)
.finish();
Ok(HttpResponse::SeeOther()
.append_header((
actix_jwt_session::DEFAULT_HEADER_NAME,
actix_jwt_session::JWT_HEADER_NAME,
format!("Bearer {bearer_token}").as_str(),
))
.append_header(("Location", "/"))
.append_header(("HX-Redirect", "/"))
.cookie(cookie)
.cookie(jwt_cookie)
.cookie(refresh_cookie)
.body(""))
}