Compare commits

...

2 Commits

2 changed files with 210 additions and 95 deletions

View File

@ -125,10 +125,13 @@ use actix_web::{dev::ServiceRequest, HttpResponse};
use actix_web::{FromRequest, HttpMessage}; use actix_web::{FromRequest, HttpMessage};
use async_trait::async_trait; use async_trait::async_trait;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation}; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::Deserialize;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime;
use uuid::Uuid; use uuid::Uuid;
use std::borrow::Cow;
/// Default authorization header is "Authorization" /// Default authorization header is "Authorization"
pub static DEFAULT_HEADER_NAME: &str = "Authorization"; pub static DEFAULT_HEADER_NAME: &str = "Authorization";
@ -142,6 +145,31 @@ pub trait Claims: PartialEq + DeserializeOwned + Serialize + Clone + Send + Sync
fn subject(&self) -> &str; fn subject(&self) -> &str;
} }
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct RefreshToken {
pub access_jti: uuid::Uuid,
pub access_exp: std::time::Duration,
pub refresh_jti: uuid::Uuid,
pub refresh_exp: std::time::Duration,
pub iat: std::time::SystemTime,
}
impl Claims for RefreshToken {
fn jti(&self) -> uuid::Uuid {
self.refresh_jti
}
fn subject(&self) -> &str {
"refresh-token"
}
}
pub struct Pair<ClaimsType: Claims> {
pub jwt: Authenticated<ClaimsType>,
pub refresh: Authenticated<RefreshToken>,
}
/// Session related errors /// Session related errors
#[derive(Debug, thiserror::Error, PartialEq)] #[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error { pub enum Error {
@ -159,6 +187,8 @@ pub enum Error {
SerializeFailed, SerializeFailed,
#[error("Unable to write claims to storage")] #[error("Unable to write claims to storage")]
WriteFailed, WriteFailed,
#[error("Access token expired")]
JWTExpired,
} }
impl actix_web::ResponseError for Error { impl actix_web::ResponseError for Error {
@ -310,73 +340,193 @@ impl<T: Claims> FromRequest for MaybeAuthenticated<T> {
/// postgresql. /// postgresql.
#[async_trait(?Send)] #[async_trait(?Send)]
pub trait TokenStorage: Send + Sync { pub trait TokenStorage: Send + Sync {
type ClaimsType: Claims;
/// Load claims from storage or returns [Error] if record does not exists or there was other /// Load claims from storage or returns [Error] if record does not exists or there was other
/// error while trying to fetch data from storage. /// error while trying to fetch data from storage.
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>; async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error>;
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID) /// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
async fn set_by_jti( async fn set_by_jti(
self: Arc<Self>, self: Arc<Self>,
claims: Self::ClaimsType, jwt_jti: &[u8],
refresh_jti: &[u8],
bytes: &[u8],
exp: std::time::Duration, exp: std::time::Duration,
) -> Result<(), Error>; ) -> Result<(), Error>;
/// Erase claims from storage. You may ignore if claims does not exists in storage. /// Erase claims from storage. You may ignore if claims does not exists in storage.
/// Redis implementation returns [Error::NotFound] if record does not exists. /// Redis implementation returns [Error::NotFound] if record does not exists.
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error>; async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error>;
} }
/// Allow to save, read and remove session from storage. /// Allow to save, read and remove session from storage.
#[derive(Clone)] #[derive(Clone)]
pub struct SessionStorage<ClaimsType: Claims> { pub struct SessionStorage<ClaimsType: Claims> {
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>, storage: Arc<dyn TokenStorage>,
jwt_encoding_key: Arc<EncodingKey>, jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm, algorithm: Algorithm,
__ty: PhantomData<ClaimsType>,
} }
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> { impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>; type Target = Arc<dyn TokenStorage>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.storage &self.storage
} }
} }
#[derive(Serialize, Deserialize, Clone)]
pub struct SessionRecord {
refresh_jti: uuid::Uuid,
jwt_jti: uuid::Uuid,
refresh_token: Vec<u8>,
jwt: Vec<u8>,
}
impl SessionRecord {
fn new<ClaimsType: Claims>(claims: ClaimsType, refresh: RefreshToken) -> Result<Self, Error> {
let refresh_jti = claims.jti();
let jwt_jti = refresh.refresh_jti;
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
let jwt = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
Ok(Self {
refresh_jti,
jwt_jti,
refresh_token,
jwt,
})
}
fn refresh_token(&self) -> Result<RefreshToken, Error> {
bincode::deserialize(&self.refresh_token).map_err(|_| Error::RecordMalformed)
}
fn jwt_token<ClaimsType: Claims>(&self) -> Result<ClaimsType, Error> {
bincode::deserialize(&self.jwt).map_err(|_| Error::RecordMalformed)
}
fn set_refresh_token(&mut self, refresh: RefreshToken) -> Result<(), Error> {
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
self.refresh_token = refresh_token;
Ok(())
}
}
impl<ClaimsType: Claims> SessionStorage<ClaimsType> { impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
pub async fn set_by_jti( pub fn new(
&self, storage: Arc<dyn TokenStorage>,
claims: ClaimsType, jwt_encoding_key: Arc<EncodingKey>,
exp: std::time::Duration, algorithm: Algorithm,
) -> Result<(), Error> { ) -> Self {
self.storage.clone().set_by_jti(claims, exp).await Self {
storage,
jwt_encoding_key,
algorithm,
__ty: Default::default(),
}
} }
/// Load claims from storage or returns [Error] if record does not exists or there was other /// Load claims from storage or returns [Error] if record does not exists or there was other
/// error while trying to fetch data from storage. /// error while trying to fetch data from storage.
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> { pub async fn find_jwt(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
self.storage.clone().get_from_jti(jti).await let record = self.load_pair_by_jwt(jti).await?;
let refresh_token = record.refresh_token()?;
if refresh_token.iat + refresh_token.access_exp < std::time::SystemTime::now() {
return Err(Error::JWTExpired);
}
record.jwt_token()
}
pub async fn refresh(&self, refresh_jti: uuid::Uuid) -> Result<(), Error> {
let mut record = self.load_pair_by_refresh(refresh_jti).await?;
let mut refresh_token = record.refresh_token()?;
let exp = refresh_token.refresh_exp;
refresh_token.iat = SystemTime::now();
record.set_refresh_token(refresh_token)?;
self.store_pair(record, exp).await?;
Ok(())
} }
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID) /// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
pub async fn store( pub async fn store(
&self, &self,
claims: ClaimsType, claims: ClaimsType,
exp: std::time::Duration, jwt_exp: std::time::Duration,
) -> Result<Authenticated<ClaimsType>, Error> { refresh_exp: std::time::Duration,
self.set_by_jti(claims.clone(), exp).await?; ) -> Result<Pair<ClaimsType>, Error> {
Ok(Authenticated { let refresh = RefreshToken {
claims: Arc::new(claims), refresh_jti: uuid::Uuid::new_v4(),
jwt_encoding_key: self.jwt_encoding_key.clone(), refresh_exp,
algorithm: self.algorithm, access_jti: claims.jti(),
access_exp: jwt_exp,
iat: std::time::SystemTime::now(),
};
let record = SessionRecord::new(claims.clone(), refresh.clone())?;
self.store_pair(record, refresh_exp).await?;
Ok(Pair {
jwt: Authenticated {
claims: Arc::new(claims),
jwt_encoding_key: self.jwt_encoding_key.clone(),
algorithm: self.algorithm,
},
refresh: Authenticated {
claims: Arc::new(refresh),
jwt_encoding_key: self.jwt_encoding_key.clone(),
algorithm: self.algorithm,
},
}) })
} }
/// Erase claims from storage. You may ignore if claims does not exists in storage. /// Erase claims from storage. You may ignore if claims does not exists in storage.
/// Redis implementation returns [Error::NotFound] if record does not exists. /// Redis implementation returns [Error::NotFound] if record does not exists.
pub async fn erase(&self, jti: Uuid) -> Result<(), Error> { pub async fn erase(&self, jti: Uuid) -> Result<(), Error> {
self.storage.clone().remove_by_jti(jti).await let record = self.load_pair_by_jwt(jti).await?;
self.storage
.clone()
.remove_by_jti(record.refresh_jti.as_bytes())
.await?;
self.storage
.clone()
.remove_by_jti(record.jwt_jti.as_bytes())
.await?;
Ok(())
}
async fn store_pair(
&self,
record: SessionRecord,
exp: std::time::Duration,
) -> Result<(), Error> {
let value = bincode::serialize(&record).map_err(|_| Error::SerializeFailed)?;
self.storage
.clone()
.set_by_jti(
record.jwt_jti.as_bytes(),
record.refresh_jti.as_bytes(),
&value,
exp,
)
.await?;
Ok(())
}
async fn load_pair_by_jwt(&self, jti: Uuid) -> Result<SessionRecord, Error> {
self.storage
.clone()
.get_by_jti(jti.as_bytes())
.await
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
}
async fn load_pair_by_refresh(&self, jti: Uuid) -> Result<SessionRecord, Error> {
self.storage
.clone()
.get_by_jti(jti.as_bytes())
.await
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
} }
} }
@ -440,7 +590,19 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
jwt_decoding_key: Arc<DecodingKey>, jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm, algorithm: Algorithm,
storage: SessionStorage<ClaimsType>, storage: SessionStorage<ClaimsType>,
) -> Result<(), Error>; ) -> Result<(), Error> {
let Some(as_str) = self.extract_token_text(req).await else {
return Ok(());
};
let decoded_claims = self.decode(&*as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
}
/// Decode encrypted JWT to structure /// Decode encrypted JWT to structure
fn decode( fn decode(
@ -468,7 +630,7 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
) -> Result<(), Error> { ) -> Result<(), Error> {
let stored = storage let stored = storage
.clone() .clone()
.get_from_jti(claims.jti()) .find_jwt(claims.jti())
.await .await
.map_err(|_| Error::InvalidSession)?; .map_err(|_| Error::InvalidSession)?;
@ -477,6 +639,8 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
} }
Ok(()) Ok(())
} }
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>>;
} }
/// Extracts JWT token from HTTP Request cookies. This extractor should be used when you can't set /// Extracts JWT token from HTTP Request cookies. This extractor should be used when you can't set
@ -501,26 +665,8 @@ impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
#[async_trait(?Send)] #[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> { impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
async fn extract_jwt( async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
&self, req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into)
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let Some(cookie) = req.cookie(self.cookie_name) else {
return Ok(())
};
let as_str = cookie.value();
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
} }
} }
@ -546,37 +692,10 @@ impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
#[async_trait(?Send)] #[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> { impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
async fn extract_jwt( async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
&self, req
req: &ServiceRequest,
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let Some(authorisation_header) = req
.headers() .headers()
.get(self.header_name) .get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into)
else {
return Ok(())
};
let as_str = authorisation_header
.to_str()
.map_err(|_| Error::NoAuthHeader)?;
let as_str = as_str
.strip_prefix("Bearer ")
.or_else(|| as_str.strip_prefix("bearer "))
.unwrap_or(as_str);
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
self.validate(&decoded_claims, storage).await?;
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims),
jwt_encoding_key,
algorithm,
});
Ok(())
} }
} }

View File

@ -10,6 +10,7 @@
use super::*; use super::*;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use futures_util::future::LocalBoxFuture; use futures_util::future::LocalBoxFuture;
use redis::aio::ConnectionLike;
use redis::AsyncCommands; use redis::AsyncCommands;
use std::future::{ready, Ready}; use std::future::{ready, Ready};
use std::marker::PhantomData; use std::marker::PhantomData;
@ -37,38 +38,37 @@ impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
where where
ClaimsType: Claims, ClaimsType: Claims,
{ {
type ClaimsType = ClaimsType; async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error> {
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
let pool = self.pool.clone(); let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?; let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
let val = conn conn.get::<_, Vec<u8>>(jti)
.get::<_, Vec<u8>>(jti.as_bytes())
.await .await
.map_err(|_| Error::NotFound)?; .map_err(|_| Error::NotFound)
bincode::deserialize(&val).map_err(|_| Error::RecordMalformed)
} }
async fn set_by_jti( async fn set_by_jti(
self: Arc<Self>, self: Arc<Self>,
claims: Self::ClaimsType, jwt_jti: &[u8],
refresh_jti: &[u8],
bytes: &[u8],
exp: std::time::Duration, exp: std::time::Duration,
) -> Result<(), Error> { ) -> Result<(), Error> {
let pool = self.pool.clone(); let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?; let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
let val = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?; let mut pipeline = redis::Pipeline::new();
conn.set_ex::<_, _, String>(claims.jti().as_bytes(), val, exp.as_secs() as usize) pipeline
.set_ex(jwt_jti, bytes, exp.as_secs() as usize)
.set_ex(refresh_jti, bytes, exp.as_secs() as usize);
conn.req_packed_commands(&pipeline, 0, 2)
.await .await
.map_err(|_| Error::WriteFailed)?; .map_err(|_| Error::WriteFailed)?;
Ok(()) Ok(())
} }
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error> { async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error> {
let pool = self.pool.clone(); let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?; let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
conn.del(jti.as_bytes()) conn.del(jti).await.map_err(|_| Error::NotFound)?;
.await
.map_err(|_| Error::NotFound)?;
Ok(()) Ok(())
} }
} }
@ -154,17 +154,13 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
pool: redis_async_pool::RedisPool, pool: redis_async_pool::RedisPool,
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>, extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
) -> Self { ) -> Self {
let storage = Arc::new(RedisStorage::new(pool)); let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
Self { Self {
jwt_encoding_key: jwt_encoding_key.clone(), jwt_encoding_key: jwt_encoding_key.clone(),
jwt_decoding_key, jwt_decoding_key,
algorithm, algorithm,
storage: SessionStorage { storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm),
storage,
jwt_encoding_key: jwt_encoding_key.clone(),
algorithm,
},
extractors: Arc::new(extractors), extractors: Arc::new(extractors),
_claims_type_marker: Default::default(), _claims_type_marker: Default::default(),
} }