Improve session managment
This commit is contained in:
parent
4edccc7661
commit
38026a6f42
@ -125,10 +125,13 @@ use actix_web::{dev::ServiceRequest, HttpResponse};
|
|||||||
use actix_web::{FromRequest, HttpMessage};
|
use actix_web::{FromRequest, HttpMessage};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
||||||
|
use serde::Deserialize;
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::SystemTime;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
/// Default authorization header is "Authorization"
|
/// Default authorization header is "Authorization"
|
||||||
pub static DEFAULT_HEADER_NAME: &str = "Authorization";
|
pub static DEFAULT_HEADER_NAME: &str = "Authorization";
|
||||||
@ -142,6 +145,31 @@ pub trait Claims: PartialEq + DeserializeOwned + Serialize + Clone + Send + Sync
|
|||||||
fn subject(&self) -> &str;
|
fn subject(&self) -> &str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RefreshToken {
|
||||||
|
pub access_jti: uuid::Uuid,
|
||||||
|
pub access_exp: std::time::Duration,
|
||||||
|
|
||||||
|
pub refresh_jti: uuid::Uuid,
|
||||||
|
pub refresh_exp: std::time::Duration,
|
||||||
|
|
||||||
|
pub iat: std::time::SystemTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Claims for RefreshToken {
|
||||||
|
fn jti(&self) -> uuid::Uuid {
|
||||||
|
self.refresh_jti
|
||||||
|
}
|
||||||
|
fn subject(&self) -> &str {
|
||||||
|
"refresh-token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Pair<ClaimsType: Claims> {
|
||||||
|
pub jwt: Authenticated<ClaimsType>,
|
||||||
|
pub refresh: Authenticated<RefreshToken>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Session related errors
|
/// Session related errors
|
||||||
#[derive(Debug, thiserror::Error, PartialEq)]
|
#[derive(Debug, thiserror::Error, PartialEq)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
@ -159,6 +187,8 @@ pub enum Error {
|
|||||||
SerializeFailed,
|
SerializeFailed,
|
||||||
#[error("Unable to write claims to storage")]
|
#[error("Unable to write claims to storage")]
|
||||||
WriteFailed,
|
WriteFailed,
|
||||||
|
#[error("Access token expired")]
|
||||||
|
JWTExpired,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl actix_web::ResponseError for Error {
|
impl actix_web::ResponseError for Error {
|
||||||
@ -310,73 +340,193 @@ impl<T: Claims> FromRequest for MaybeAuthenticated<T> {
|
|||||||
/// postgresql.
|
/// postgresql.
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
pub trait TokenStorage: Send + Sync {
|
pub trait TokenStorage: Send + Sync {
|
||||||
type ClaimsType: Claims;
|
|
||||||
|
|
||||||
/// Load claims from storage or returns [Error] if record does not exists or there was other
|
/// Load claims from storage or returns [Error] if record does not exists or there was other
|
||||||
/// error while trying to fetch data from storage.
|
/// error while trying to fetch data from storage.
|
||||||
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
|
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error>;
|
||||||
|
|
||||||
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
|
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
|
||||||
async fn set_by_jti(
|
async fn set_by_jti(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
claims: Self::ClaimsType,
|
jwt_jti: &[u8],
|
||||||
|
refresh_jti: &[u8],
|
||||||
|
bytes: &[u8],
|
||||||
exp: std::time::Duration,
|
exp: std::time::Duration,
|
||||||
) -> Result<(), Error>;
|
) -> Result<(), Error>;
|
||||||
|
|
||||||
/// Erase claims from storage. You may ignore if claims does not exists in storage.
|
/// Erase claims from storage. You may ignore if claims does not exists in storage.
|
||||||
/// Redis implementation returns [Error::NotFound] if record does not exists.
|
/// Redis implementation returns [Error::NotFound] if record does not exists.
|
||||||
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error>;
|
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allow to save, read and remove session from storage.
|
/// Allow to save, read and remove session from storage.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SessionStorage<ClaimsType: Claims> {
|
pub struct SessionStorage<ClaimsType: Claims> {
|
||||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
storage: Arc<dyn TokenStorage>,
|
||||||
jwt_encoding_key: Arc<EncodingKey>,
|
jwt_encoding_key: Arc<EncodingKey>,
|
||||||
algorithm: Algorithm,
|
algorithm: Algorithm,
|
||||||
|
__ty: PhantomData<ClaimsType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
|
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
|
||||||
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>;
|
type Target = Arc<dyn TokenStorage>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.storage
|
&self.storage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Clone)]
|
||||||
|
pub struct SessionRecord {
|
||||||
|
refresh_jti: uuid::Uuid,
|
||||||
|
jwt_jti: uuid::Uuid,
|
||||||
|
refresh_token: Vec<u8>,
|
||||||
|
jwt: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionRecord {
|
||||||
|
fn new<ClaimsType: Claims>(claims: ClaimsType, refresh: RefreshToken) -> Result<Self, Error> {
|
||||||
|
let refresh_jti = claims.jti();
|
||||||
|
let jwt_jti = refresh.refresh_jti;
|
||||||
|
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
|
||||||
|
let jwt = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
|
||||||
|
Ok(Self {
|
||||||
|
refresh_jti,
|
||||||
|
jwt_jti,
|
||||||
|
refresh_token,
|
||||||
|
jwt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn refresh_token(&self) -> Result<RefreshToken, Error> {
|
||||||
|
bincode::deserialize(&self.refresh_token).map_err(|_| Error::RecordMalformed)
|
||||||
|
}
|
||||||
|
fn jwt_token<ClaimsType: Claims>(&self) -> Result<ClaimsType, Error> {
|
||||||
|
bincode::deserialize(&self.jwt).map_err(|_| Error::RecordMalformed)
|
||||||
|
}
|
||||||
|
fn set_refresh_token(&mut self, refresh: RefreshToken) -> Result<(), Error> {
|
||||||
|
let refresh_token = bincode::serialize(&refresh).map_err(|_| Error::SerializeFailed)?;
|
||||||
|
self.refresh_token = refresh_token;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
|
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
|
||||||
pub async fn set_by_jti(
|
pub fn new(
|
||||||
&self,
|
storage: Arc<dyn TokenStorage>,
|
||||||
claims: ClaimsType,
|
jwt_encoding_key: Arc<EncodingKey>,
|
||||||
exp: std::time::Duration,
|
algorithm: Algorithm,
|
||||||
) -> Result<(), Error> {
|
) -> Self {
|
||||||
self.storage.clone().set_by_jti(claims, exp).await
|
Self {
|
||||||
|
storage,
|
||||||
|
jwt_encoding_key,
|
||||||
|
algorithm,
|
||||||
|
__ty: Default::default(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load claims from storage or returns [Error] if record does not exists or there was other
|
/// Load claims from storage or returns [Error] if record does not exists or there was other
|
||||||
/// error while trying to fetch data from storage.
|
/// error while trying to fetch data from storage.
|
||||||
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
pub async fn find_jwt(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
||||||
self.storage.clone().get_from_jti(jti).await
|
let record = self.load_pair_by_jwt(jti).await?;
|
||||||
|
let refresh_token = record.refresh_token()?;
|
||||||
|
if refresh_token.iat + refresh_token.access_exp < std::time::SystemTime::now() {
|
||||||
|
return Err(Error::JWTExpired);
|
||||||
|
}
|
||||||
|
record.jwt_token()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn refresh(&self, refresh_jti: uuid::Uuid) -> Result<(), Error> {
|
||||||
|
let mut record = self.load_pair_by_refresh(refresh_jti).await?;
|
||||||
|
let mut refresh_token = record.refresh_token()?;
|
||||||
|
let exp = refresh_token.refresh_exp;
|
||||||
|
refresh_token.iat = SystemTime::now();
|
||||||
|
record.set_refresh_token(refresh_token)?;
|
||||||
|
self.store_pair(record, exp).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
|
/// Save claims in storage in a way claims can be loaded from database using `jti` as [uuid::Uuid] (JWT ID)
|
||||||
pub async fn store(
|
pub async fn store(
|
||||||
&self,
|
&self,
|
||||||
claims: ClaimsType,
|
claims: ClaimsType,
|
||||||
exp: std::time::Duration,
|
jwt_exp: std::time::Duration,
|
||||||
) -> Result<Authenticated<ClaimsType>, Error> {
|
refresh_exp: std::time::Duration,
|
||||||
self.set_by_jti(claims.clone(), exp).await?;
|
) -> Result<Pair<ClaimsType>, Error> {
|
||||||
Ok(Authenticated {
|
let refresh = RefreshToken {
|
||||||
claims: Arc::new(claims),
|
refresh_jti: uuid::Uuid::new_v4(),
|
||||||
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
refresh_exp,
|
||||||
algorithm: self.algorithm,
|
access_jti: claims.jti(),
|
||||||
|
access_exp: jwt_exp,
|
||||||
|
iat: std::time::SystemTime::now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let record = SessionRecord::new(claims.clone(), refresh.clone())?;
|
||||||
|
self.store_pair(record, refresh_exp).await?;
|
||||||
|
|
||||||
|
Ok(Pair {
|
||||||
|
jwt: Authenticated {
|
||||||
|
claims: Arc::new(claims),
|
||||||
|
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||||
|
algorithm: self.algorithm,
|
||||||
|
},
|
||||||
|
refresh: Authenticated {
|
||||||
|
claims: Arc::new(refresh),
|
||||||
|
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||||
|
algorithm: self.algorithm,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Erase claims from storage. You may ignore if claims does not exists in storage.
|
/// Erase claims from storage. You may ignore if claims does not exists in storage.
|
||||||
/// Redis implementation returns [Error::NotFound] if record does not exists.
|
/// Redis implementation returns [Error::NotFound] if record does not exists.
|
||||||
pub async fn erase(&self, jti: Uuid) -> Result<(), Error> {
|
pub async fn erase(&self, jti: Uuid) -> Result<(), Error> {
|
||||||
self.storage.clone().remove_by_jti(jti).await
|
let record = self.load_pair_by_jwt(jti).await?;
|
||||||
|
|
||||||
|
self.storage
|
||||||
|
.clone()
|
||||||
|
.remove_by_jti(record.refresh_jti.as_bytes())
|
||||||
|
.await?;
|
||||||
|
self.storage
|
||||||
|
.clone()
|
||||||
|
.remove_by_jti(record.jwt_jti.as_bytes())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn store_pair(
|
||||||
|
&self,
|
||||||
|
record: SessionRecord,
|
||||||
|
exp: std::time::Duration,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let value = bincode::serialize(&record).map_err(|_| Error::SerializeFailed)?;
|
||||||
|
|
||||||
|
self.storage
|
||||||
|
.clone()
|
||||||
|
.set_by_jti(
|
||||||
|
record.jwt_jti.as_bytes(),
|
||||||
|
record.refresh_jti.as_bytes(),
|
||||||
|
&value,
|
||||||
|
exp,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn load_pair_by_jwt(&self, jti: Uuid) -> Result<SessionRecord, Error> {
|
||||||
|
self.storage
|
||||||
|
.clone()
|
||||||
|
.get_by_jti(jti.as_bytes())
|
||||||
|
.await
|
||||||
|
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
|
||||||
|
}
|
||||||
|
async fn load_pair_by_refresh(&self, jti: Uuid) -> Result<SessionRecord, Error> {
|
||||||
|
self.storage
|
||||||
|
.clone()
|
||||||
|
.get_by_jti(jti.as_bytes())
|
||||||
|
.await
|
||||||
|
.and_then(|bytes| bincode::deserialize(&bytes).map_err(|_| Error::RecordMalformed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -440,7 +590,19 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
|
|||||||
jwt_decoding_key: Arc<DecodingKey>,
|
jwt_decoding_key: Arc<DecodingKey>,
|
||||||
algorithm: Algorithm,
|
algorithm: Algorithm,
|
||||||
storage: SessionStorage<ClaimsType>,
|
storage: SessionStorage<ClaimsType>,
|
||||||
) -> Result<(), Error>;
|
) -> Result<(), Error> {
|
||||||
|
let Some(as_str) = self.extract_token_text(req).await else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
let decoded_claims = self.decode(&*as_str, jwt_decoding_key, algorithm)?;
|
||||||
|
self.validate(&decoded_claims, storage).await?;
|
||||||
|
req.extensions_mut().insert(Authenticated {
|
||||||
|
claims: Arc::new(decoded_claims),
|
||||||
|
jwt_encoding_key,
|
||||||
|
algorithm,
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Decode encrypted JWT to structure
|
/// Decode encrypted JWT to structure
|
||||||
fn decode(
|
fn decode(
|
||||||
@ -468,7 +630,7 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
|
|||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let stored = storage
|
let stored = storage
|
||||||
.clone()
|
.clone()
|
||||||
.get_from_jti(claims.jti())
|
.find_jwt(claims.jti())
|
||||||
.await
|
.await
|
||||||
.map_err(|_| Error::InvalidSession)?;
|
.map_err(|_| Error::InvalidSession)?;
|
||||||
|
|
||||||
@ -477,6 +639,8 @@ pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts JWT token from HTTP Request cookies. This extractor should be used when you can't set
|
/// Extracts JWT token from HTTP Request cookies. This extractor should be used when you can't set
|
||||||
@ -501,26 +665,8 @@ impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
|
|||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
|
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
|
||||||
async fn extract_jwt(
|
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
|
||||||
&self,
|
req.cookie(self.cookie_name).map(|c| c.value().to_string()).map(Into::into)
|
||||||
req: &ServiceRequest,
|
|
||||||
jwt_encoding_key: Arc<EncodingKey>,
|
|
||||||
jwt_decoding_key: Arc<DecodingKey>,
|
|
||||||
algorithm: Algorithm,
|
|
||||||
storage: SessionStorage<ClaimsType>,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
let Some(cookie) = req.cookie(self.cookie_name) else {
|
|
||||||
return Ok(())
|
|
||||||
};
|
|
||||||
let as_str = cookie.value();
|
|
||||||
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
|
|
||||||
self.validate(&decoded_claims, storage).await?;
|
|
||||||
req.extensions_mut().insert(Authenticated {
|
|
||||||
claims: Arc::new(decoded_claims),
|
|
||||||
jwt_encoding_key,
|
|
||||||
algorithm,
|
|
||||||
});
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -546,37 +692,10 @@ impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
|
|||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait(?Send)]
|
||||||
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
|
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
|
||||||
async fn extract_jwt(
|
async fn extract_token_text<'req>(&self, req: &'req ServiceRequest) -> Option<Cow<'req, str>> {
|
||||||
&self,
|
req
|
||||||
req: &ServiceRequest,
|
|
||||||
jwt_encoding_key: Arc<EncodingKey>,
|
|
||||||
jwt_decoding_key: Arc<DecodingKey>,
|
|
||||||
algorithm: Algorithm,
|
|
||||||
storage: SessionStorage<ClaimsType>,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
let Some(authorisation_header) = req
|
|
||||||
.headers()
|
.headers()
|
||||||
.get(self.header_name)
|
.get(self.header_name).and_then(|h| h.to_str().ok()).map(Into::into)
|
||||||
else {
|
|
||||||
return Ok(())
|
|
||||||
};
|
|
||||||
let as_str = authorisation_header
|
|
||||||
.to_str()
|
|
||||||
.map_err(|_| Error::NoAuthHeader)?;
|
|
||||||
|
|
||||||
let as_str = as_str
|
|
||||||
.strip_prefix("Bearer ")
|
|
||||||
.or_else(|| as_str.strip_prefix("bearer "))
|
|
||||||
.unwrap_or(as_str);
|
|
||||||
|
|
||||||
let decoded_claims = self.decode(as_str, jwt_decoding_key, algorithm)?;
|
|
||||||
self.validate(&decoded_claims, storage).await?;
|
|
||||||
req.extensions_mut().insert(Authenticated {
|
|
||||||
claims: Arc::new(decoded_claims),
|
|
||||||
jwt_encoding_key,
|
|
||||||
algorithm,
|
|
||||||
});
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
|
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
|
||||||
use futures_util::future::LocalBoxFuture;
|
use futures_util::future::LocalBoxFuture;
|
||||||
|
use redis::aio::ConnectionLike;
|
||||||
use redis::AsyncCommands;
|
use redis::AsyncCommands;
|
||||||
use std::future::{ready, Ready};
|
use std::future::{ready, Ready};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
@ -37,38 +38,37 @@ impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
|
|||||||
where
|
where
|
||||||
ClaimsType: Claims,
|
ClaimsType: Claims,
|
||||||
{
|
{
|
||||||
type ClaimsType = ClaimsType;
|
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error> {
|
||||||
|
|
||||||
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
|
||||||
let pool = self.pool.clone();
|
let pool = self.pool.clone();
|
||||||
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
||||||
let val = conn
|
conn.get::<_, Vec<u8>>(jti)
|
||||||
.get::<_, Vec<u8>>(jti.as_bytes())
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| Error::NotFound)?;
|
.map_err(|_| Error::NotFound)
|
||||||
bincode::deserialize(&val).map_err(|_| Error::RecordMalformed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn set_by_jti(
|
async fn set_by_jti(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
claims: Self::ClaimsType,
|
jwt_jti: &[u8],
|
||||||
|
refresh_jti: &[u8],
|
||||||
|
bytes: &[u8],
|
||||||
exp: std::time::Duration,
|
exp: std::time::Duration,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let pool = self.pool.clone();
|
let pool = self.pool.clone();
|
||||||
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
||||||
let val = bincode::serialize(&claims).map_err(|_| Error::SerializeFailed)?;
|
let mut pipeline = redis::Pipeline::new();
|
||||||
conn.set_ex::<_, _, String>(claims.jti().as_bytes(), val, exp.as_secs() as usize)
|
pipeline
|
||||||
|
.set_ex(jwt_jti, bytes, exp.as_secs() as usize)
|
||||||
|
.set_ex(refresh_jti, bytes, exp.as_secs() as usize);
|
||||||
|
conn.req_packed_commands(&pipeline, 0, 2)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| Error::WriteFailed)?;
|
.map_err(|_| Error::WriteFailed)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_by_jti(self: Arc<Self>, jti: Uuid) -> Result<(), Error> {
|
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error> {
|
||||||
let pool = self.pool.clone();
|
let pool = self.pool.clone();
|
||||||
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
|
||||||
conn.del(jti.as_bytes())
|
conn.del(jti).await.map_err(|_| Error::NotFound)?;
|
||||||
.await
|
|
||||||
.map_err(|_| Error::NotFound)?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,17 +154,13 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
|
|||||||
pool: redis_async_pool::RedisPool,
|
pool: redis_async_pool::RedisPool,
|
||||||
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
|
extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let storage = Arc::new(RedisStorage::new(pool));
|
let storage = Arc::new(RedisStorage::<ClaimsType>::new(pool));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
jwt_encoding_key: jwt_encoding_key.clone(),
|
jwt_encoding_key: jwt_encoding_key.clone(),
|
||||||
jwt_decoding_key,
|
jwt_decoding_key,
|
||||||
algorithm,
|
algorithm,
|
||||||
storage: SessionStorage {
|
storage: SessionStorage::new(storage, jwt_encoding_key.clone(), algorithm),
|
||||||
storage,
|
|
||||||
jwt_encoding_key: jwt_encoding_key.clone(),
|
|
||||||
algorithm,
|
|
||||||
},
|
|
||||||
extractors: Arc::new(extractors),
|
extractors: Arc::new(extractors),
|
||||||
_claims_type_marker: Default::default(),
|
_claims_type_marker: Default::default(),
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user