Using session

This commit is contained in:
Adrian Woźniak 2023-08-14 12:30:32 +02:00
parent 5560f068b1
commit f265d22b87
7 changed files with 186 additions and 99 deletions

View File

@ -1,6 +1,6 @@
use actix_web::{dev::ServiceRequest, HttpResponse}; use actix_web::{dev::ServiceRequest, HttpResponse};
use actix_web::{FromRequest, HttpMessage}; use actix_web::{FromRequest, HttpMessage};
use jsonwebtoken::{decode, DecodingKey, Validation, EncodingKey, encode, Algorithm}; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc; use std::sync::Arc;
@ -58,7 +58,11 @@ impl<T> std::ops::Deref for Authenticated<T> {
impl<T: Claims> Authenticated<T> { impl<T: Claims> Authenticated<T> {
pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> { pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
encode(&jsonwebtoken::Header::new(self.algorithm), &*self.claims, &*self.jwt_encoding_key) encode(
&jsonwebtoken::Header::new(self.algorithm),
&*self.claims,
&*self.jwt_encoding_key,
)
} }
} }
@ -79,7 +83,7 @@ impl<T: Claims> FromRequest for Authenticated<T> {
} }
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
pub trait TokenStorage { pub trait TokenStorage: Send + Sync {
type ClaimsType: Claims; type ClaimsType: Claims;
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>; async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
@ -89,6 +93,48 @@ pub trait TokenStorage {
claims: Self::ClaimsType, claims: Self::ClaimsType,
exp: std::time::Duration, exp: std::time::Duration,
) -> Result<(), Error>; ) -> Result<(), Error>;
fn jwt_encoding_key(&self) -> Arc<EncodingKey>;
fn algorithm(&self) -> Algorithm;
}
#[derive(Clone)]
pub struct SessionStorage<ClaimsType: Claims>(Arc<dyn TokenStorage<ClaimsType = ClaimsType>>);
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
pub async fn set_by_jti(
&self,
claims: ClaimsType,
exp: std::time::Duration,
) -> Result<(), Error> {
self.0.clone().set_by_jti(claims, exp).await
}
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
self.0.clone().get_from_jti(jti).await
}
pub async fn store(
&self,
claims: ClaimsType,
exp: std::time::Duration,
) -> Result<Authenticated<ClaimsType>, Error> {
self.set_by_jti(claims.clone(), exp).await?;
Ok(Authenticated {
claims: Arc::new(claims),
jwt_encoding_key: self.0.jwt_encoding_key(),
algorithm: self.algorithm(),
})
}
} }
struct Extractor; struct Extractor;
@ -99,7 +145,7 @@ impl Extractor {
jwt_encoding_key: Arc<EncodingKey>, jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>, jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm, algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>, storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let Some(authorisation_header) = req let Some(authorisation_header) = req
.headers() .headers()
@ -111,13 +157,17 @@ impl Extractor {
.to_str() .to_str()
.map_err(|_| Error::NoAuthHeader)?; .map_err(|_| Error::NoAuthHeader)?;
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm)) let decoded_claims =
.map_err(|_e| { decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
// let error_message = e.to_string(); |_e| {
Error::InvalidSession // let error_message = e.to_string();
})?; Error::InvalidSession
},
)?;
let stored = storage let stored = storage
.0
.clone()
.get_from_jti(decoded_claims.claims.jti()) .get_from_jti(decoded_claims.claims.jti())
.await .await
.map_err(|_| Error::InvalidSession)?; .map_err(|_| Error::InvalidSession)?;
@ -126,10 +176,11 @@ impl Extractor {
return Err(Error::InvalidSession); return Err(Error::InvalidSession);
} }
req.extensions_mut() req.extensions_mut().insert(Authenticated {
.insert(Authenticated { claims: Arc::new(decoded_claims.claims),
claims: Arc::new(decoded_claims.claims), jwt_encoding_key,
}); algorithm,
});
Ok(()) Ok(())
} }
} }

View File

@ -1,7 +1,6 @@
use super::*; use super::*;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use futures_util::future::LocalBoxFuture; use futures_util::future::LocalBoxFuture;
use jsonwebtoken::{DecodingKey, Validation};
use redis::AsyncCommands; use redis::AsyncCommands;
use std::future::{ready, Ready}; use std::future::{ready, Ready};
use std::marker::PhantomData; use std::marker::PhantomData;
@ -9,15 +8,23 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
#[derive(Clone)] #[derive(Clone)]
pub struct RedisStorage<ClaimsType: Claims> { struct RedisStorage<ClaimsType: Claims> {
pool: redis_async_pool::RedisPool, pool: redis_async_pool::RedisPool,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
_claims_type_marker: PhantomData<ClaimsType>, _claims_type_marker: PhantomData<ClaimsType>,
} }
impl<ClaimsType: Claims> RedisStorage<ClaimsType> { impl<ClaimsType: Claims> RedisStorage<ClaimsType> {
pub fn new(pool: redis_async_pool::RedisPool) -> Self { pub fn new(
pool: redis_async_pool::RedisPool,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
) -> Self {
Self { Self {
pool, pool,
jwt_encoding_key,
algorithm,
_claims_type_marker: Default::default(), _claims_type_marker: Default::default(),
} }
} }
@ -53,6 +60,14 @@ where
.map_err(|_| Error::WriteFailed)?; .map_err(|_| Error::WriteFailed)?;
Ok(()) Ok(())
} }
fn jwt_encoding_key(&self) -> Arc<EncodingKey> {
self.jwt_encoding_key.clone()
}
fn algorithm(&self) -> Algorithm {
self.algorithm
}
} }
pub struct RedisMiddleware<S, ClaimsType> pub struct RedisMiddleware<S, ClaimsType>
@ -61,10 +76,10 @@ where
{ {
_claims_type_marker: std::marker::PhantomData<ClaimsType>, _claims_type_marker: std::marker::PhantomData<ClaimsType>,
service: Rc<S>, service: Rc<S>,
jwt_encoding_key: Arc<EncodingKey>, jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>, jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm, algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>, storage: SessionStorage<ClaimsType>,
} }
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType> impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
@ -88,7 +103,14 @@ where
let storage = self.storage.clone(); let storage = self.storage.clone();
async move { async move {
Extractor::extract_bearer_jwt(&req, jwt_encoding_key, jwt_decoding_key, algorithm, storage).await?; Extractor::extract_bearer_jwt(
&req,
jwt_encoding_key,
jwt_decoding_key,
algorithm,
storage,
)
.await?;
let res = svc.call(req).await?; let res = svc.call(req).await?;
Ok(res) Ok(res)
} }
@ -98,10 +120,10 @@ where
#[derive(Clone)] #[derive(Clone)]
pub struct RedisMiddlewareFactory<ClaimsType: Claims> { pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
jwt_encoding_key: Arc<EncodingKey>, jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>, jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm, algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>, storage: SessionStorage<ClaimsType>,
_claims_type_marker: PhantomData<ClaimsType>, _claims_type_marker: PhantomData<ClaimsType>,
} }
@ -112,15 +134,20 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
algorithm: Algorithm, algorithm: Algorithm,
pool: redis_async_pool::RedisPool, pool: redis_async_pool::RedisPool,
) -> Self { ) -> Self {
let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm));
Self { Self {
jwt_encoding_key, jwt_encoding_key,
jwt_decoding_key, jwt_decoding_key,
algorithm, algorithm,
storage: RedisStorage::new( storage: SessionStorage(storage),
pool),
_claims_type_marker: Default::default(), _claims_type_marker: Default::default(),
} }
} }
pub fn storage(&self) -> SessionStorage<ClaimsType> {
self.storage.clone()
}
} }
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for RedisMiddlewareFactory<ClaimsType> impl<S, B, ClaimsType> Transform<S, ServiceRequest> for RedisMiddlewareFactory<ClaimsType>
@ -138,8 +165,9 @@ where
ready(Ok(RedisMiddleware { ready(Ok(RedisMiddleware {
service: Rc::new(service), service: Rc::new(service),
storage: self.storage.clone(), storage: self.storage.clone(),
jwt_encoding_key: self.jwt_encoding_key.clone(),
jwt_decoding_key: self.jwt_decoding_key.clone(), jwt_decoding_key: self.jwt_decoding_key.clone(),
jwt_validator: self.jwt_validator.clone(), algorithm: self.algorithm,
_claims_type_marker: PhantomData, _claims_type_marker: PhantomData,
})) }))
} }

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage}; use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage, TokenStorage};
use actix_web::get;
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
use actix_web::web::Data; use actix_web::web::{Data, Json};
use actix_web::HttpResponse; use actix_web::HttpResponse;
use actix_web::{get, post};
use actix_web::{http::header::ContentType, test, App}; use actix_web::{http::header::ContentType, test, App};
use jsonwebtoken::*; use jsonwebtoken::*;
use ring::rand::SystemRandom; use ring::rand::SystemRandom;
@ -25,8 +25,6 @@ impl actix_jwt_session::Claims for Claims {
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
async fn not_authenticated() { async fn not_authenticated() {
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
let validator = Validation::new(JWT_SIGNING_ALGO);
let redis = { let redis = {
use redis_async_pool::{RedisConnectionManager, RedisPool}; use redis_async_pool::{RedisConnectionManager, RedisPool};
RedisPool::new( RedisPool::new(
@ -41,12 +39,14 @@ async fn not_authenticated() {
let keys = JwtSigningKeys::generate().unwrap(); let keys = JwtSigningKeys::generate().unwrap();
let factory = RedisMiddlewareFactory::<Claims>::new( let factory = RedisMiddlewareFactory::<Claims>::new(
Arc::new(keys.encoding_key),
Arc::new(keys.decoding_key), Arc::new(keys.decoding_key),
Arc::new(validator), Algorithm::EdDSA,
Arc::new(RedisStorage::new(redis.clone())), redis.clone(),
); );
let app = App::new() let app = App::new()
.app_data(factory.storage())
.wrap(factory.clone()) .wrap(factory.clone())
.app_data(Data::new(redis.clone())) .app_data(Data::new(redis.clone()))
.service(sign_in) .service(sign_in)
@ -56,26 +56,56 @@ async fn not_authenticated() {
let app = actix_web::test::init_service(app).await; let app = actix_web::test::init_service(app).await;
let res = test::call_service(&app, test::TestRequest::default() let res = test::call_service(
.insert_header(ContentType::plaintext()) &app,
.to_request()).await; test::TestRequest::default()
.insert_header(ContentType::plaintext())
.to_request(),
)
.await;
assert!(res.status().is_success()); assert!(res.status().is_success());
let res = test::call_service(&app, test::TestRequest::default() let res = test::call_service(
.uri("/s") &app,
.insert_header(ContentType::plaintext()) test::TestRequest::default()
.to_request()).await; .uri("/s")
let s = StatusCode::UNAUTHORIZED; .insert_header(ContentType::plaintext())
assert_eq!(res.status(), s); .to_request(),
)
.await;
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
let origina_claims = Claims { id: Uuid::new_v4() };
let res = test::call_service(
&app,
test::TestRequest::default()
.uri("/in")
.method(actix_web::http::Method::POST)
.insert_header(ContentType::json())
.set_json(&origina_claims)
.to_request(),
)
.await;
assert_eq!(res.status(), StatusCode::OK);
} }
#[get("/in")] #[post("/in")]
async fn sign_in(store: Data<RedisStorage<Claims>>) -> HttpResponse { async fn sign_in(
HttpResponse::Ok().body("") store: Data<RedisStorage<Claims>>,
claims: Json<Claims>,
) -> Result<HttpResponse, actix_web::Error> {
let claims = claims.into_inner();
let store = store.into_inner();
store
.clone()
.set_by_jti(claims, std::time::Duration::from_secs(300))
.await
.unwrap();
Ok(HttpResponse::Ok().body(""))
} }
#[get("/out")] #[post("/out")]
async fn sign_out(store: Data<RedisStorage<Claims>>) -> HttpResponse { async fn sign_out(_store: Data<RedisStorage<Claims>>) -> HttpResponse {
HttpResponse::Ok().body("") HttpResponse::Ok().body("")
} }

View File

@ -48,9 +48,11 @@ async fn main() -> std::io::Result<()> {
HttpServer::new(move || { HttpServer::new(move || {
let session_config = session_config.clone(); let session_config = session_config.clone();
let session_factory = session_config.factory();
App::new() App::new()
.wrap(middleware::Logger::default()) .wrap(middleware::Logger::default())
.wrap(session_config.factory()) .app_data(Data::new(session_factory.storage()))
.wrap(session_factory)
.app_data(Data::new(conn.clone())) .app_data(Data::new(conn.clone()))
.app_data(Data::new(redis.clone())) .app_data(Data::new(redis.clone()))
.app_data(Data::new(l10n.clone())) .app_data(Data::new(l10n.clone()))

View File

@ -1,6 +1,7 @@
use std::ops::Add; use std::ops::Add;
use std::sync::Arc; use std::sync::Arc;
use actix_jwt_session::SessionStorage;
pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory}; pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpResponse}; use actix_web::{get, post, HttpResponse};
@ -13,7 +14,6 @@ use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair}; use ring::signature::{Ed25519KeyPair, KeyPair};
use sea_orm::DatabaseConnection; use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use time::ext::*;
use time::OffsetDateTime; use time::OffsetDateTime;
mod extract_session; mod extract_session;
@ -24,7 +24,7 @@ pub use oswilno_view::filters;
pub type UserSession = Claims; pub type UserSession = Claims;
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct JWTTtl(time::Duration); pub struct JWTTtl(std::time::Duration);
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)] #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -73,19 +73,15 @@ pub struct LoginResponse {
claims: Claims, claims: Claims,
} }
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
#[derive(Clone)] #[derive(Clone)]
pub struct SessionConfigurator { pub struct SessionConfigurator {
jwt_ttl: Data<JWTTtl>, jwt_ttl: Data<JWTTtl>,
encoding_key: Data<EncodingKey>,
factory: RedisMiddlewareFactory<Claims>, factory: RedisMiddlewareFactory<Claims>,
} }
impl SessionConfigurator { impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) { pub fn app_data(self, config: &mut ServiceConfig) {
config config
.app_data(self.encoding_key)
.app_data(self.jwt_ttl) .app_data(self.jwt_ttl)
.service(login) .service(login)
.service(login_view) .service(login_view)
@ -130,17 +126,16 @@ impl SessionConfigurator {
} }
pub fn new(redis: redis_async_pool::RedisPool) -> Self { pub fn new(redis: redis_async_pool::RedisPool) -> Self {
let jwt_ttl = JWTTtl(31.days()); let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60));
let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
let validator = Validation::new(JWT_SIGNING_ALGO);
let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new( let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new(
Arc::new(jwt_signing_keys.encoding_key),
Arc::new(jwt_signing_keys.decoding_key), Arc::new(jwt_signing_keys.decoding_key),
Arc::new(validator), Algorithm::EdDSA,
Arc::new(actix_jwt_session::RedisStorage::new(redis)), redis,
); );
Self { Self {
encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()),
jwt_ttl: Data::new(jwt_ttl.clone()), jwt_ttl: Data::new(jwt_ttl.clone()),
factory: auth_middleware_factory, factory: auth_middleware_factory,
} }
@ -186,18 +181,16 @@ async fn login_partial_view(t: Data<TranslationStorage>) -> SignInPartialTemplat
#[autometrics] #[autometrics]
#[post("/login")] #[post("/login")]
async fn login( async fn login(
jwt_encoding_key: Data<EncodingKey>,
jwt_ttl: Data<JWTTtl>, jwt_ttl: Data<JWTTtl>,
db: Data<DatabaseConnection>, db: Data<DatabaseConnection>,
redis: Data<SessionStorage<Claims>>,
payload: Form<SignInPayload>, payload: Form<SignInPayload>,
t: Data<oswilno_view::TranslationStorage>, t: Data<oswilno_view::TranslationStorage>,
lang: Lang, lang: Lang,
redis: Data<redis_async_pool::RedisPool>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let t = t.into_inner(); let t = t.into_inner();
let mut errors = Errors::default(); let mut errors = Errors::default();
match login_inner( match login_inner(
jwt_encoding_key,
jwt_ttl.into_inner(), jwt_ttl.into_inner(),
payload.into_inner(), payload.into_inner(),
db.into_inner(), db.into_inner(),
@ -221,11 +214,10 @@ async fn login(
} }
async fn login_inner( async fn login_inner(
jwt_encoding_key: Data<EncodingKey>,
jwt_ttl: Arc<JWTTtl>, jwt_ttl: Arc<JWTTtl>,
payload: SignInPayload, payload: SignInPayload,
db: Arc<DatabaseConnection>, db: Arc<DatabaseConnection>,
redis: Arc<redis_async_pool::RedisPool>, redis: Arc<SessionStorage<Claims>>,
errors: &mut Errors, errors: &mut Errors,
) -> Result<LoginResponse, SignInPayload> { ) -> Result<LoginResponse, SignInPayload> {
let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; let iat = OffsetDateTime::now_utc().unix_timestamp() as usize;
@ -261,42 +253,24 @@ async fn login_inner(
audience: Audience::Web, audience: Audience::Web,
jwt_id: uuid::Uuid::new_v4(), jwt_id: uuid::Uuid::new_v4(),
}; };
let jwt_token = encode( let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await {
&Header::new(JWT_SIGNING_ALGO), Err(e) => {
&jwt_claims,
&jwt_encoding_key,
)
.map_err(|_| {
errors.push_global("Bad credentials");
payload.clone()
})?;
let Ok(bin_value) = bincode::serialize(&jwt_claims) else {
errors.push_global("Failed to sign in. Please try later");
return Err(payload.clone());
};
{
use redis::AsyncCommands;
let Ok(mut conn) = redis.get().await else {
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
};
if let Err(e) = conn
.set_ex::<'_, _, _, String>(
jwt_claims.jwt_id.as_bytes(),
bin_value,
jwt_ttl.0.as_seconds_f32() as usize,
)
.await
{
tracing::warn!("Failed to set sign-in claims in redis: {e}"); tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later"); errors.push_global("Failed to sign in. Please try later");
return Err(payload); return Err(payload);
} }
} Ok(jwt_token) => jwt_token,
};
let bearer_token = match jwt_token.encode() {
Ok(token) => token,
Err(e) => {
tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
}
};
Ok(LoginResponse { Ok(LoginResponse {
bearer_token: jwt_token, bearer_token,
claims: jwt_claims, claims: jwt_claims,
}) })
} }

View File

@ -8,7 +8,7 @@ pub mod lang;
#[derive(Debug, askama_actix::Template)] #[derive(Debug, askama_actix::Template)]
#[template(path = "../templates/base.html")] #[template(path = "../templates/base.html")]
pub struct Layout<BodyTemplate: askama::DynTemplate + std::fmt::Display> { pub struct Layout<BodyTemplate: askama::Template> {
pub main: BodyTemplate, pub main: BodyTemplate,
} }

View File

@ -11,6 +11,8 @@
</head> </head>
<body> <body>
<base url="/" /> <base url="/" />
<main>{{ main }}</main> <main>
{{ main|safe }}
</main>
</body> </body>
</html> </html>