New JWT implementation

This commit is contained in:
Adrian Woźniak 2023-08-11 15:25:36 +02:00
parent 26dea34054
commit 6692df9aeb
4 changed files with 209 additions and 0 deletions

View File

@ -0,0 +1,23 @@
[package]
name = "actix-jwt-session"
version = "0.1.0"
edition = "2021"
[features]
default = ['use-redis']
use-redis = ["redis", "redis-async-pool"]
[dependencies]
actix-web = "4"
async-trait = "0.1.72"
bincode = "1.3.3"
futures = "0.3.28"
futures-lite = "1.13.0"
futures-util = { version = "0.3.28", features = ['async-await'] }
jsonwebtoken = "8.3.0"
redis = { version = "0.17", optional = true }
redis-async-pool = { version = "0.2.4", optional = true }
serde = { version = "1.0.183", features = ["derive"] }
thiserror = "1.0.44"
tokio = { version = "1.30.0", features = ["full"] }
uuid = { version = "1.4.1", features = ["v4"] }

View File

@ -0,0 +1,85 @@
use actix_web::dev::ServiceRequest;
use actix_web::HttpMessage;
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
pub trait Claims: PartialEq + DeserializeOwned + Serialize + Clone + Send + Sync + 'static {
fn jti(&self) -> uuid::Uuid;
}
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
#[error("Failed to obtain redis connection")]
RedisConn,
#[error("Record not found")]
NotFound,
#[error("Record malformed")]
RecordMalformed,
#[error("Invalid session")]
InvalidSession,
#[error("No http authentication header")]
NoAuthHeader,
}
impl actix_web::ResponseError for Error {
fn status_code(&self) -> actix_web::http::StatusCode {
match self {
Self::RedisConn => actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
_ => actix_web::http::StatusCode::UNAUTHORIZED,
}
}
}
pub struct Authenticated<T>(Arc<T>);
#[async_trait::async_trait(?Send)]
pub trait TokenStorage {
type ClaimsType: Claims;
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
}
struct Extractor;
impl Extractor {
async fn extract_bearer_jwt<ClaimsType: Claims>(
req: &ServiceRequest,
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
) -> Result<(), Error> {
let authorisation_header = req
.headers()
.get("Authorization")
.ok_or(Error::NoAuthHeader)?;
let as_str = authorisation_header
.to_str()
.map_err(|_| Error::NoAuthHeader)?;
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &*jwt_validator)
.map_err(|_e| {
// let error_message = e.to_string();
Error::InvalidSession
})?;
let stored = storage
.get_from_jti(decoded_claims.claims.jti())
.await
.map_err(|_| Error::InvalidSession)?;
if stored != decoded_claims.claims {
return Err(Error::InvalidSession);
}
req.extensions_mut()
.insert(Authenticated(Arc::new(decoded_claims.claims)));
Ok(())
}
}
#[cfg(feature = "redis")]
mod redis_adapter;
#[cfg(feature = "redis")]
pub use redis_adapter::*;

View File

@ -0,0 +1,100 @@
use super::*;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse};
use futures_util::future::LocalBoxFuture;
use jsonwebtoken::{DecodingKey, Validation};
use redis::AsyncCommands;
use std::marker::PhantomData;
use std::rc::Rc;
use std::sync::Arc;
#[derive(Clone)]
pub struct RedisStorage<ClaimsType: Claims> {
pool: redis_async_pool::RedisPool,
_claims_type_marker: PhantomData<ClaimsType>,
}
impl<ClaimsType: Claims> RedisStorage<ClaimsType> {
pub fn new(pool: redis_async_pool::RedisPool) -> Self {
Self {
pool,
_claims_type_marker: Default::default(),
}
}
}
#[async_trait::async_trait(?Send)]
impl<ClaimsType> TokenStorage for RedisStorage<ClaimsType>
where
ClaimsType: Claims,
{
type ClaimsType = ClaimsType;
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
let pool = self.pool.clone();
let mut conn = pool.get().await.map_err(|_| Error::RedisConn)?;
let val = conn
.get::<_, Vec<u8>>(jti.as_bytes())
.await
.map_err(|_| Error::NotFound)?;
bincode::deserialize(&val).map_err(|_| Error::RecordMalformed)
}
}
pub struct RedisMiddleware<S, ClaimsType>
where
ClaimsType: Claims,
{
_claims_type_marker: std::marker::PhantomData<ClaimsType>,
service: Rc<S>,
jwt_decoding_key: Arc<DecodingKey>,
jwt_validator: Arc<Validation>,
storage: Arc<RedisStorage<ClaimsType>>,
}
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
where
ClaimsType: Claims,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
use futures_lite::FutureExt;
let svc = self.service.clone();
let jwt_decoding_key = self.jwt_decoding_key.clone();
let validation = self.jwt_validator.clone();
let storage = self.storage.clone();
async move {
Extractor::extract_bearer_jwt(&req, jwt_decoding_key, validation, storage).await?;
let res = svc.call(req).await?;
Ok(res)
}
.boxed_local()
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
struct Out {
id: uuid::Uuid,
}
impl Claims for Out {
fn jti(&self) -> uuid::Uuid {
self.id
}
}
#[tokio::test]
async fn extract() {
}
}

View File

@ -0,0 +1 @@