Improve session managment

This commit is contained in:
eraden 2023-08-23 22:52:17 +02:00
parent 4edccc7661
commit 38026a6f42
2 changed files with 210 additions and 95 deletions

View File

@ -125,10 +125,13 @@ use actix_web::{dev::ServiceRequest, HttpResponse};
use actix_web::{FromRequest, HttpMessage};
use async_trait::async_trait;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::Deserialize;
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::SystemTime;
use uuid::Uuid;
use std::borrow::Cow;
/// Default authorization header is "Authorization"
pub static DEFAULT_HEADER_NAME: &str = "Authorization";
@ -142,6 +145,31 @@ pub trait Claims: PartialEq + DeserializeOwned + Serialize + Clone + Send + Sync
fn subject(&self) -> &str;
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct RefreshToken {
pub access_jti: uuid::Uuid,
pub access_exp: std::time::Duration,
pub refresh_jti: uuid::Uuid,
pub refresh_exp: std::time::Duration,
pub iat: std::time::SystemTime,
}
impl Claims for RefreshToken {
fn jti(&self) -> uuid::Uuid {
self.refresh_jti
}
fn subject(&self) -> &str {
"refresh-token"
}
}
pub struct Pair<ClaimsType: Claims> {
pub jwt: Authenticated<ClaimsType>,
pub refresh: Authenticated<RefreshToken>,
}
/// Session related errors
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
@ -159,6 +187,8 @@ pub enum Error {
SerializeFailed,
#[error("Unable to write claims to storage")]
WriteFailed,
#[error("Access token expired")]
JWTExpired,
}
impl actix_web::ResponseError for Error {
@ -310,73 +340,193 @@ impl<T: Claims> FromRequest for MaybeAuthenticated<T> {
/// postgresql.
#[async_trait(?Send)]
pub trait TokenStorage: Send + Sync {
type ClaimsType: Claims;
/// Load claims from storage or returns [Error] if record does not exists or there was other
/// error while trying to fetch data from storage.
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error>;
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
async fn set_by_jti(
self: Arc<Self>,
claims: Self::ClaimsType,
jwt_jti: &[u8],
refresh_jti: &[u8],
bytes: &[u8],
exp: std::time::Duration,
) -> Result<(), Error>;
/// Erase claims from storage. You may ignore if claims does not exists in storage.
/// Redis implementation returns [Error::NotFound] if record does not exists.
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error>;
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error>;
}
/// Allow to save, read and remove session from storage.
#[derive(Clone)]
pub struct SessionStorage<ClaimsType: Claims> {
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
storage: Arc<dyn TokenStorage>,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
__ty: PhantomData<ClaimsType>,
}
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>;
type Target = Arc<dyn TokenStorage>;
fn deref(&self) -> &Self::Target {
&self.storage
}
}
#[derive(Serialize, Deserialize, Clone)]
pub struct SessionRecord {
refresh_jti: uuid::Uuid,
jwt_jti: uuid::Uuid,
refresh_token: Vec<u8>,
jwt: Vec<u8>,
}
impl SessionRecord {
fn new<ClaimsType: Claims>(claims: ClaimsType, refresh: RefreshToken) -> Result<Self, Error> {
let refresh_jti = claims.jti();
let jwt_jti = refresh.refresh_jti;
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
let jwt = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
Ok(Self {
refresh_jti,
jwt_jti,
refresh_token,
jwt,
})
}
fn refresh_token(&self) -> Result<RefreshToken, Error> {
bincode::deserialize(&self.refresh_token).map_err(|_| Error::RecordMalformed)
}
fn jwt_token<ClaimsType: Claims>(&self) -> Result<ClaimsType, Error> {
bincode::deserialize(&self.jwt).map_err(|_| Error::RecordMalformed)
}
fn set_refresh_token(&mut self, refresh: RefreshToken) -> Result<(), Error> {
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
self.refresh_token = refresh_token;
Ok(())
}
}
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
pub async fn set_by_jti(
&self,
claims: ClaimsType,
exp: std::time::Duration,
) -> Result<(), Error> {
self.storage.clone().set_by_jti(claims, exp).await
pub fn new(
storage: Arc<dyn TokenStorage>,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
) -> Self {
Self {
storage,
jwt_encoding_key,
algorithm,
__ty: Default::default(),
}
}
/// Load claims from storage or returns [Error] if record does not exists or there was other
/// error while trying to fetch data from storage.
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
self.storage.clone().get_from_jti(jti).await
pub async fn find_jwt(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
let record = self.load_pair_by_jwt(jti).await?;
let refresh_token = record.refresh_token()?;
if refresh_token.iat + refresh_token.access_exp < std::time::SystemTime::now() {
return Err(Error::JWTExpired);
}
record.jwt_token()
}
pub async fn refresh(&self, refresh_jti: uuid::Uuid) -> Result<(), Error> {
let mut record = self.load_pair_by_refresh(refresh_jti).await?;
let mut refresh_token = record.refresh_token()?;
let exp = refresh_token.refresh_exp;
refresh_token.iat = SystemTime::now();
record.set_refresh_token(refresh_token)?;
self.store_pair(record, exp).await?;
Ok(())
}
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
pub async fn store(
&self,
claims: ClaimsType,
exp: std::time::Duration,
) -> Result<Authenticated<ClaimsType>, Error> {
self.set_by_jti(claims.clone(), exp).await?;
Ok(Authenticated {
jwt_exp: std::time::Duration,
refresh_exp: std::time::Duration,
) -> Result<Pair<ClaimsType>, Error> {
let refresh = RefreshToken {
refresh_jti: uuid::Uuid::new_v4(),
refresh_exp,
access_jti: claims.jti(),
access_exp: jwt_exp,
iat: std::time::SystemTime::now(),
};
let record = SessionRecord::new(claims.clone(), refresh.clone())?;
self.store_pair(record, refresh_exp).await?;
Ok(Pair {
jwt: Authenticated {
claims: Arc::new(claims),
jwt_encoding_key: self.jwt_encoding_key.clone(),
algorithm: self.algorithm,
},
refresh: Authenticated {
claims: Arc::new(refresh),
jwt_encoding_key: self.jwt_encoding_key.clone(),
algorithm: self.algorithm,
},
})
}
/// Erase claims from storage. You may ignore if claims does not exists in storage.
/// Redis implementation returns [Error::NotFound] if record does not exists.
pub async fn erase(&self, jti: Uuid) -> Result<(), Error> {
self.storage.clone().remove_by_jti(jti).await
let record = self.load_pair_by_jwt(jti).await?;
self.storage
.clone()
.remove_by_jti(record.refresh_jti.as_bytes())
.await?;
self.storage
.clone()
.remove_by_jti(record.jwt_jti.as_bytes())
.await?;
Ok(())
}
async fn store_pair(
&self,
record: SessionRecord,
exp: std::time::Duration,
) -> Result<(), Error> {
let value = bincode::serialize(&record).map_err(|_| Error::SerializeFailed)?;
self.storage
.clone()
.set_by_jti(
record.jwt_jti.as_bytes(),
record.refresh_jti.as_bytes(),
&value,
exp,
)
.await?;
Ok(())
}
async fn load_pair_by_jwt(&self, jti: Uuid) -> Result<SessionRecord, Error> {
self.storage
.clone()
.get_by_jti(jti.as_bytes())
.await
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
}
async fn load_pair_by_refresh(&self, jti: Uuid) -> Result<SessionRecord, Error> {
self.storage
.clone()
.get_by_jti(jti.as_bytes())
.await
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
}
}
@ -440,7 +590,19 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error>;
) -> Result<(), Error> {
let Some(as_str) = self.extract_token_text(req).await else {
return Ok(());
};
let decoded_claims = self.decode(&*as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
}
/// Decode encrypted JWT to structure
fn decode(
@ -468,7 +630,7 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
) -> Result<(), Error> {
let stored = storage
.clone()
.get_from_jti(claims.jti())
.find_jwt(claims.jti())
.await
.map_err(|_| Error::InvalidSession)?;
@ -477,6 +639,8 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
}
Ok(())
}
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>>;
}
/// Extracts JWT token from HTTP Request cookies. This extractor should be used when you can't set
@ -501,26 +665,8 @@ impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
async fn extract_jwt(
&self,
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let Some(cookie) = req.cookie(self.cookie_name) else {
return Ok(())
};
let as_str = cookie.value();
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into)
}
}
@ -546,37 +692,10 @@ impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
async fn extract_jwt(
&self,
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let Some(authorisation_header) = req
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
req
.headers()
.get(self.header_name)
else {
return Ok(())
};
let as_str = authorisation_header
.to_str()
.map_err(|_| Error::NoAuthHeader)?;
let as_str = as_str
.strip_prefix("Bearer ")
.or_else(|| as_str.strip_prefix("bearer "))
.unwrap_or(as_str);
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
.get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into)
}
}

View File

@ -10,6 +10,7 @@
use super::*;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use futures_util::future::LocalBoxFuture;
use redis::aio::ConnectionLike;
use redis::AsyncCommands;
use std::future::{ready, Ready};
use std::marker::PhantomData;
@ -37,38 +38,37 @@ impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
where
ClaimsType: Claims,
{
type ClaimsType = ClaimsType;
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error> {
let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
let val = conn
.get::<_, Vec<u8>>(jti.as_bytes())
conn.get::<_, Vec<u8>>(jti)
.await
.map_err(|_| Error::NotFound)?;
bincode::deserialize(&val).map_err(|_| Error::RecordMalformed)
.map_err(|_| Error::NotFound)
}
async fn set_by_jti(
self: Arc<Self>,
claims: Self::ClaimsType,
jwt_jti: &[u8],
refresh_jti: &[u8],
bytes: &[u8],
exp: std::time::Duration,
) -> Result<(), Error> {
let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
let val = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
conn.set_ex::<_, _, String>(claims.jti().as_bytes(), val, exp.as_secs() as usize)
let mut pipeline = redis::Pipeline::new();
pipeline
.set_ex(jwt_jti, bytes, exp.as_secs() as usize)
.set_ex(refresh_jti, bytes, exp.as_secs() as usize);
conn.req_packed_commands(&pipeline, 0, 2)
.await
.map_err(|_| Error::WriteFailed)?;
Ok(())
}
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error> {
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error> {
let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
conn.del(jti.as_bytes())
.await
.map_err(|_| Error::NotFound)?;
conn.del(jti).await.map_err(|_| Error::NotFound)?;
Ok(())
}
}
@ -154,17 +154,13 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
pool: redis_async_pool::RedisPool,
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
) -> Self {
let storage = Arc::new(RedisStorage::new(pool));
let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
Self {
jwt_encoding_key: jwt_encoding_key.clone(),
jwt_decoding_key,
algorithm,
storage: SessionStorage {
storage,
jwt_encoding_key: jwt_encoding_key.clone(),
algorithm,
},
storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm),
extractors: Arc::new(extractors),
_claims_type_marker: Default::default(),
}