Improve session managment
This commit is contained in:
parent
4edccc7661
commit
38026a6f42
@ -125,10 +125,13 @@ use actix_web::{dev::ServiceRequest, HttpResponse};
|
||||
use actix_web::{FromRequest, HttpMessage};
|
||||
use async_trait::async_trait;
|
||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
||||
use serde::Deserialize;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use uuid::Uuid;
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Default authorization header is "Authorization"
|
||||
pub static DEFAULT_HEADER_NAME: &str = "Authorization";
|
||||
@ -142,6 +145,31 @@ pub trait Claims: PartialEq + DeserializeOwned + Serialize + Clone + Send + Sync
|
||||
fn subject(&self) -> &str;
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct RefreshToken {
|
||||
pub access_jti: uuid::Uuid,
|
||||
pub access_exp: std::time::Duration,
|
||||
|
||||
pub refresh_jti: uuid::Uuid,
|
||||
pub refresh_exp: std::time::Duration,
|
||||
|
||||
pub iat: std::time::SystemTime,
|
||||
}
|
||||
|
||||
impl Claims for RefreshToken {
|
||||
fn jti(&self) -> uuid::Uuid {
|
||||
self.refresh_jti
|
||||
}
|
||||
fn subject(&self) -> &str {
|
||||
"refresh-token"
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Pair<ClaimsType: Claims> {
|
||||
pub jwt: Authenticated<ClaimsType>,
|
||||
pub refresh: Authenticated<RefreshToken>,
|
||||
}
|
||||
|
||||
/// Session related errors
|
||||
#[derive(Debug, thiserror::Error, PartialEq)]
|
||||
pub enum Error {
|
||||
@ -159,6 +187,8 @@ pub enum Error {
|
||||
SerializeFailed,
|
||||
#[error("Unable to write claims to storage")]
|
||||
WriteFailed,
|
||||
#[error("Access token expired")]
|
||||
JWTExpired,
|
||||
}
|
||||
|
||||
impl actix_web::ResponseError for Error {
|
||||
@ -310,73 +340,193 @@ impl<T: Claims> FromRequest for MaybeAuthenticated<T> {
|
||||
/// postgresql.
|
||||
#[async_trait(?Send)]
|
||||
pub trait TokenStorage: Send + Sync {
|
||||
type ClaimsType: Claims;
|
||||
|
||||
/// Load claims from storage or returns [Error] if record does not exists or there was other
|
||||
/// error while trying to fetch data from storage.
|
||||
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
|
||||
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error>;
|
||||
|
||||
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
|
||||
async fn set_by_jti(
|
||||
self: Arc<Self>,
|
||||
claims: Self::ClaimsType,
|
||||
jwt_jti: &[u8],
|
||||
refresh_jti: &[u8],
|
||||
bytes: &[u8],
|
||||
exp: std::time::Duration,
|
||||
) -> Result<(), Error>;
|
||||
|
||||
/// Erase claims from storage. You may ignore if claims does not exists in storage.
|
||||
/// Redis implementation returns [Error::NotFound] if record does not exists.
|
||||
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error>;
|
||||
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
/// Allow to save, read and remove session from storage.
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStorage<ClaimsType: Claims> {
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
storage: Arc<dyn TokenStorage>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
algorithm: Algorithm,
|
||||
__ty: PhantomData<ClaimsType>,
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
|
||||
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>;
|
||||
type Target = Arc<dyn TokenStorage>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct SessionRecord {
|
||||
refresh_jti: uuid::Uuid,
|
||||
jwt_jti: uuid::Uuid,
|
||||
refresh_token: Vec<u8>,
|
||||
jwt: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SessionRecord {
|
||||
fn new<ClaimsType: Claims>(claims: ClaimsType, refresh: RefreshToken) -> Result<Self, Error> {
|
||||
let refresh_jti = claims.jti();
|
||||
let jwt_jti = refresh.refresh_jti;
|
||||
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
|
||||
let jwt = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
|
||||
Ok(Self {
|
||||
refresh_jti,
|
||||
jwt_jti,
|
||||
refresh_token,
|
||||
jwt,
|
||||
})
|
||||
}
|
||||
|
||||
fn refresh_token(&self) -> Result<RefreshToken, Error> {
|
||||
bincode::deserialize(&self.refresh_token).map_err(|_| Error::RecordMalformed)
|
||||
}
|
||||
fn jwt_token<ClaimsType: Claims>(&self) -> Result<ClaimsType, Error> {
|
||||
bincode::deserialize(&self.jwt).map_err(|_| Error::RecordMalformed)
|
||||
}
|
||||
fn set_refresh_token(&mut self, refresh: RefreshToken) -> Result<(), Error> {
|
||||
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
|
||||
self.refresh_token = refresh_token;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
|
||||
pub async fn set_by_jti(
|
||||
&self,
|
||||
claims: ClaimsType,
|
||||
exp: std::time::Duration,
|
||||
) -> Result<(), Error> {
|
||||
self.storage.clone().set_by_jti(claims, exp).await
|
||||
pub fn new(
|
||||
storage: Arc<dyn TokenStorage>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
algorithm: Algorithm,
|
||||
) -> Self {
|
||||
Self {
|
||||
storage,
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
__ty: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load claims from storage or returns [Error] if record does not exists or there was other
|
||||
/// error while trying to fetch data from storage.
|
||||
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
||||
self.storage.clone().get_from_jti(jti).await
|
||||
pub async fn find_jwt(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
||||
let record = self.load_pair_by_jwt(jti).await?;
|
||||
let refresh_token = record.refresh_token()?;
|
||||
if refresh_token.iat + refresh_token.access_exp < std::time::SystemTime::now() {
|
||||
return Err(Error::JWTExpired);
|
||||
}
|
||||
record.jwt_token()
|
||||
}
|
||||
|
||||
pub async fn refresh(&self, refresh_jti: uuid::Uuid) -> Result<(), Error> {
|
||||
let mut record = self.load_pair_by_refresh(refresh_jti).await?;
|
||||
let mut refresh_token = record.refresh_token()?;
|
||||
let exp = refresh_token.refresh_exp;
|
||||
refresh_token.iat = SystemTime::now();
|
||||
record.set_refresh_token(refresh_token)?;
|
||||
self.store_pair(record, exp).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
|
||||
pub async fn store(
|
||||
&self,
|
||||
claims: ClaimsType,
|
||||
exp: std::time::Duration,
|
||||
) -> Result<Authenticated<ClaimsType>, Error> {
|
||||
self.set_by_jti(claims.clone(), exp).await?;
|
||||
Ok(Authenticated {
|
||||
jwt_exp: std::time::Duration,
|
||||
refresh_exp: std::time::Duration,
|
||||
) -> Result<Pair<ClaimsType>, Error> {
|
||||
let refresh = RefreshToken {
|
||||
refresh_jti: uuid::Uuid::new_v4(),
|
||||
refresh_exp,
|
||||
access_jti: claims.jti(),
|
||||
access_exp: jwt_exp,
|
||||
iat: std::time::SystemTime::now(),
|
||||
};
|
||||
|
||||
let record = SessionRecord::new(claims.clone(), refresh.clone())?;
|
||||
self.store_pair(record, refresh_exp).await?;
|
||||
|
||||
Ok(Pair {
|
||||
jwt: Authenticated {
|
||||
claims: Arc::new(claims),
|
||||
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||
algorithm: self.algorithm,
|
||||
},
|
||||
refresh: Authenticated {
|
||||
claims: Arc::new(refresh),
|
||||
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||
algorithm: self.algorithm,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Erase claims from storage. You may ignore if claims does not exists in storage.
|
||||
/// Redis implementation returns [Error::NotFound] if record does not exists.
|
||||
pub async fn erase(&self, jti: Uuid) -> Result<(), Error> {
|
||||
self.storage.clone().remove_by_jti(jti).await
|
||||
let record = self.load_pair_by_jwt(jti).await?;
|
||||
|
||||
self.storage
|
||||
.clone()
|
||||
.remove_by_jti(record.refresh_jti.as_bytes())
|
||||
.await?;
|
||||
self.storage
|
||||
.clone()
|
||||
.remove_by_jti(record.jwt_jti.as_bytes())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_pair(
|
||||
&self,
|
||||
record: SessionRecord,
|
||||
exp: std::time::Duration,
|
||||
) -> Result<(), Error> {
|
||||
let value = bincode::serialize(&record).map_err(|_| Error::SerializeFailed)?;
|
||||
|
||||
self.storage
|
||||
.clone()
|
||||
.set_by_jti(
|
||||
record.jwt_jti.as_bytes(),
|
||||
record.refresh_jti.as_bytes(),
|
||||
&value,
|
||||
exp,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
async fn load_pair_by_jwt(&self, jti: Uuid) -> Result<SessionRecord, Error> {
|
||||
self.storage
|
||||
.clone()
|
||||
.get_by_jti(jti.as_bytes())
|
||||
.await
|
||||
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
|
||||
}
|
||||
async fn load_pair_by_refresh(&self, jti: Uuid) -> Result<SessionRecord, Error> {
|
||||
self.storage
|
||||
.clone()
|
||||
.get_by_jti(jti.as_bytes())
|
||||
.await
|
||||
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
|
||||
}
|
||||
}
|
||||
|
||||
@ -440,7 +590,19 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
) -> Result<(), Error>;
|
||||
) -> Result<(), Error> {
|
||||
let Some(as_str) = self.extract_token_text(req).await else {
|
||||
return Ok(());
|
||||
};
|
||||
let decoded_claims = self.decode(&*as_str, jwt_decoding_key, algorithm)?;
|
||||
self.validate(&decoded_claims, storage).await?;
|
||||
req.extensions_mut().insert(Authenticated {
|
||||
claims: Arc::new(decoded_claims),
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Decode encrypted JWT to structure
|
||||
fn decode(
|
||||
@ -468,7 +630,7 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
|
||||
) -> Result<(), Error> {
|
||||
let stored = storage
|
||||
.clone()
|
||||
.get_from_jti(claims.jti())
|
||||
.find_jwt(claims.jti())
|
||||
.await
|
||||
.map_err(|_| Error::InvalidSession)?;
|
||||
|
||||
@ -477,6 +639,8 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>>;
|
||||
}
|
||||
|
||||
/// Extracts JWT token from HTTP Request cookies. This extractor should be used when you can't set
|
||||
@ -501,26 +665,8 @@ impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
|
||||
async fn extract_jwt(
|
||||
&self,
|
||||
req: &ServiceRequest,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
) -> Result<(), Error> {
|
||||
let Some(cookie) = req.cookie(self.cookie_name) else {
|
||||
return Ok(())
|
||||
};
|
||||
let as_str = cookie.value();
|
||||
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
|
||||
self.validate(&decoded_claims, storage).await?;
|
||||
req.extensions_mut().insert(Authenticated {
|
||||
claims: Arc::new(decoded_claims),
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
});
|
||||
Ok(())
|
||||
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
|
||||
req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@ -546,37 +692,10 @@ impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
|
||||
async fn extract_jwt(
|
||||
&self,
|
||||
req: &ServiceRequest,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
) -> Result<(), Error> {
|
||||
let Some(authorisation_header) = req
|
||||
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
|
||||
req
|
||||
.headers()
|
||||
.get(self.header_name)
|
||||
else {
|
||||
return Ok(())
|
||||
};
|
||||
let as_str = authorisation_header
|
||||
.to_str()
|
||||
.map_err(|_| Error::NoAuthHeader)?;
|
||||
|
||||
let as_str = as_str
|
||||
.strip_prefix("Bearer ")
|
||||
.or_else(|| as_str.strip_prefix("bearer "))
|
||||
.unwrap_or(as_str);
|
||||
|
||||
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
|
||||
self.validate(&decoded_claims, storage).await?;
|
||||
req.extensions_mut().insert(Authenticated {
|
||||
claims: Arc::new(decoded_claims),
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
});
|
||||
Ok(())
|
||||
.get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@
|
||||
use super::*;
|
||||
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
|
||||
use futures_util::future::LocalBoxFuture;
|
||||
use redis::aio::ConnectionLike;
|
||||
use redis::AsyncCommands;
|
||||
use std::future::{ready, Ready};
|
||||
use std::marker::PhantomData;
|
||||
@ -37,38 +38,37 @@ impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
|
||||
where
|
||||
ClaimsType: Claims,
|
||||
{
|
||||
type ClaimsType = ClaimsType;
|
||||
|
||||
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
||||
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
let pool = self.pool.clone();
|
||||
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
||||
let val = conn
|
||||
.get::<_, Vec<u8>>(jti.as_bytes())
|
||||
conn.get::<_, Vec<u8>>(jti)
|
||||
.await
|
||||
.map_err(|_| Error::NotFound)?;
|
||||
bincode::deserialize(&val).map_err(|_| Error::RecordMalformed)
|
||||
.map_err(|_| Error::NotFound)
|
||||
}
|
||||
|
||||
async fn set_by_jti(
|
||||
self: Arc<Self>,
|
||||
claims: Self::ClaimsType,
|
||||
jwt_jti: &[u8],
|
||||
refresh_jti: &[u8],
|
||||
bytes: &[u8],
|
||||
exp: std::time::Duration,
|
||||
) -> Result<(), Error> {
|
||||
let pool = self.pool.clone();
|
||||
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
||||
let val = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
|
||||
conn.set_ex::<_, _, String>(claims.jti().as_bytes(), val, exp.as_secs() as usize)
|
||||
let mut pipeline = redis::Pipeline::new();
|
||||
pipeline
|
||||
.set_ex(jwt_jti, bytes, exp.as_secs() as usize)
|
||||
.set_ex(refresh_jti, bytes, exp.as_secs() as usize);
|
||||
conn.req_packed_commands(&pipeline, 0, 2)
|
||||
.await
|
||||
.map_err(|_| Error::WriteFailed)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error> {
|
||||
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error> {
|
||||
let pool = self.pool.clone();
|
||||
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
||||
conn.del(jti.as_bytes())
|
||||
.await
|
||||
.map_err(|_| Error::NotFound)?;
|
||||
conn.del(jti).await.map_err(|_| Error::NotFound)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -154,17 +154,13 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
|
||||
pool: redis_async_pool::RedisPool,
|
||||
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
|
||||
) -> Self {
|
||||
let storage = Arc::new(RedisStorage::new(pool));
|
||||
let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
|
||||
|
||||
Self {
|
||||
jwt_encoding_key: jwt_encoding_key.clone(),
|
||||
jwt_decoding_key,
|
||||
algorithm,
|
||||
storage: SessionStorage {
|
||||
storage,
|
||||
jwt_encoding_key: jwt_encoding_key.clone(),
|
||||
algorithm,
|
||||
},
|
||||
storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm),
|
||||
extractors: Arc::new(extractors),
|
||||
_claims_type_marker: Default::default(),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user