Using session

This commit is contained in:
Adrian Woźniak 2023-08-14 12:30:32 +02:00
parent 5560f068b1
commit f265d22b87
7 changed files with 186 additions and 99 deletions

View File

@ -1,6 +1,6 @@
use actix_web::{dev::ServiceRequest, HttpResponse};
use actix_web::{FromRequest, HttpMessage};
use jsonwebtoken::{decode, DecodingKey, Validation, EncodingKey, encode, Algorithm};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
@ -58,7 +58,11 @@ impl<T> std::ops::Deref for Authenticated<T> {
impl<T: Claims> Authenticated<T> {
pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
encode(&jsonwebtoken::Header::new(self.algorithm), &*self.claims, &*self.jwt_encoding_key)
encode(
&jsonwebtoken::Header::new(self.algorithm),
&*self.claims,
&*self.jwt_encoding_key,
)
}
}
@ -79,7 +83,7 @@ impl<T: Claims> FromRequest for Authenticated<T> {
}
#[async_trait::async_trait(?Send)]
pub trait TokenStorage {
pub trait TokenStorage: Send + Sync {
type ClaimsType: Claims;
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
@ -89,6 +93,48 @@ pub trait TokenStorage {
claims: Self::ClaimsType,
exp: std::time::Duration,
) -> Result<(), Error>;
fn jwt_encoding_key(&self) -> Arc<EncodingKey>;
fn algorithm(&self) -> Algorithm;
}
#[derive(Clone)]
pub struct SessionStorage<ClaimsType: Claims>(Arc<dyn TokenStorage<ClaimsType = ClaimsType>>);
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
pub async fn set_by_jti(
&self,
claims: ClaimsType,
exp: std::time::Duration,
) -> Result<(), Error> {
self.0.clone().set_by_jti(claims, exp).await
}
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
self.0.clone().get_from_jti(jti).await
}
pub async fn store(
&self,
claims: ClaimsType,
exp: std::time::Duration,
) -> Result<Authenticated<ClaimsType>, Error> {
self.set_by_jti(claims.clone(), exp).await?;
Ok(Authenticated {
claims: Arc::new(claims),
jwt_encoding_key: self.0.jwt_encoding_key(),
algorithm: self.algorithm(),
})
}
}
struct Extractor;
@ -99,7 +145,7 @@ impl Extractor {
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
storage: SessionStorage<ClaimsType>,
) -> Result<(), Error> {
let Some(authorisation_header) = req
.headers()
@ -111,13 +157,17 @@ impl Extractor {
.to_str()
.map_err(|_| Error::NoAuthHeader)?;
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm))
.map_err(|_e| {
let decoded_claims =
decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
|_e| {
// let error_message = e.to_string();
Error::InvalidSession
})?;
},
)?;
let stored = storage
.0
.clone()
.get_from_jti(decoded_claims.claims.jti())
.await
.map_err(|_| Error::InvalidSession)?;
@ -126,9 +176,10 @@ impl Extractor {
return Err(Error::InvalidSession);
}
req.extensions_mut()
.insert(Authenticated {
req.extensions_mut().insert(Authenticated {
claims: Arc::new(decoded_claims.claims),
jwt_encoding_key,
algorithm,
});
Ok(())
}

View File

@ -1,7 +1,6 @@
use super::*;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use futures_util::future::LocalBoxFuture;
use jsonwebtoken::{DecodingKey, Validation};
use redis::AsyncCommands;
use std::future::{ready, Ready};
use std::marker::PhantomData;
@ -9,15 +8,23 @@ use std::rc::Rc;
use std::sync::Arc;
#[derive(Clone)]
pub struct RedisStorage<ClaimsType: Claims> {
struct RedisStorage<ClaimsType: Claims> {
pool: redis_async_pool::RedisPool,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
_claims_type_marker: PhantomData<ClaimsType>,
}
impl<ClaimsType: Claims> RedisStorage<ClaimsType> {
pub fn new(pool: redis_async_pool::RedisPool) -> Self {
pub fn new(
pool: redis_async_pool::RedisPool,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
) -> Self {
Self {
pool,
jwt_encoding_key,
algorithm,
_claims_type_marker: Default::default(),
}
}
@ -53,6 +60,14 @@ where
.map_err(|_| Error::WriteFailed)?;
Ok(())
}
fn jwt_encoding_key(&self) -> Arc<EncodingKey> {
self.jwt_encoding_key.clone()
}
fn algorithm(&self) -> Algorithm {
self.algorithm
}
}
pub struct RedisMiddleware<S, ClaimsType>
@ -64,7 +79,7 @@ where
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
storage: SessionStorage<ClaimsType>,
}
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
@ -88,7 +103,14 @@ where
let storage = self.storage.clone();
async move {
Extractor::extract_bearer_jwt(&req, jwt_encoding_key, jwt_decoding_key, algorithm, storage).await?;
Extractor::extract_bearer_jwt(
&req,
jwt_encoding_key,
jwt_decoding_key,
algorithm,
storage,
)
.await?;
let res = svc.call(req).await?;
Ok(res)
}
@ -101,7 +123,7 @@ pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
jwt_encoding_key: Arc<EncodingKey>,
jwt_decoding_key: Arc<DecodingKey>,
algorithm: Algorithm,
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
storage: SessionStorage<ClaimsType>,
_claims_type_marker: PhantomData<ClaimsType>,
}
@ -112,15 +134,20 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
algorithm: Algorithm,
pool: redis_async_pool::RedisPool,
) -> Self {
let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm));
Self {
jwt_encoding_key,
jwt_decoding_key,
algorithm,
storage: RedisStorage::new(
pool),
storage: SessionStorage(storage),
_claims_type_marker: Default::default(),
}
}
pub fn storage(&self) -> SessionStorage<ClaimsType> {
self.storage.clone()
}
}
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for RedisMiddlewareFactory<ClaimsType>
@ -138,8 +165,9 @@ where
ready(Ok(RedisMiddleware {
service: Rc::new(service),
storage: self.storage.clone(),
jwt_encoding_key: self.jwt_encoding_key.clone(),
jwt_decoding_key: self.jwt_decoding_key.clone(),
jwt_validator: self.jwt_validator.clone(),
algorithm: self.algorithm,
_claims_type_marker: PhantomData,
}))
}

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage};
use actix_web::get;
use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage, TokenStorage};
use actix_web::http::StatusCode;
use actix_web::web::Data;
use actix_web::web::{Data, Json};
use actix_web::HttpResponse;
use actix_web::{get, post};
use actix_web::{http::header::ContentType, test, App};
use jsonwebtoken::*;
use ring::rand::SystemRandom;
@ -25,8 +25,6 @@ impl actix_jwt_session::Claims for Claims {
#[tokio::test(flavor = "multi_thread")]
async fn not_authenticated() {
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
let validator = Validation::new(JWT_SIGNING_ALGO);
let redis = {
use redis_async_pool::{RedisConnectionManager, RedisPool};
RedisPool::new(
@ -41,12 +39,14 @@ async fn not_authenticated() {
let keys = JwtSigningKeys::generate().unwrap();
let factory = RedisMiddlewareFactory::<Claims>::new(
Arc::new(keys.encoding_key),
Arc::new(keys.decoding_key),
Arc::new(validator),
Arc::new(RedisStorage::new(redis.clone())),
Algorithm::EdDSA,
redis.clone(),
);
let app = App::new()
.app_data(factory.storage())
.wrap(factory.clone())
.app_data(Data::new(redis.clone()))
.service(sign_in)
@ -56,26 +56,56 @@ async fn not_authenticated() {
let app = actix_web::test::init_service(app).await;
let res = test::call_service(&app, test::TestRequest::default()
let res = test::call_service(
&app,
test::TestRequest::default()
.insert_header(ContentType::plaintext())
.to_request()).await;
.to_request(),
)
.await;
assert!(res.status().is_success());
let res = test::call_service(&app, test::TestRequest::default()
let res = test::call_service(
&app,
test::TestRequest::default()
.uri("/s")
.insert_header(ContentType::plaintext())
.to_request()).await;
let s = StatusCode::UNAUTHORIZED;
assert_eq!(res.status(), s);
.to_request(),
)
.await;
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
let origina_claims = Claims { id: Uuid::new_v4() };
let res = test::call_service(
&app,
test::TestRequest::default()
.uri("/in")
.method(actix_web::http::Method::POST)
.insert_header(ContentType::json())
.set_json(&origina_claims)
.to_request(),
)
.await;
assert_eq!(res.status(), StatusCode::OK);
}
#[get("/in")]
async fn sign_in(store: Data<RedisStorage<Claims>>) -> HttpResponse {
HttpResponse::Ok().body("")
#[post("/in")]
async fn sign_in(
store: Data<RedisStorage<Claims>>,
claims: Json<Claims>,
) -> Result<HttpResponse, actix_web::Error> {
let claims = claims.into_inner();
let store = store.into_inner();
store
.clone()
.set_by_jti(claims, std::time::Duration::from_secs(300))
.await
.unwrap();
Ok(HttpResponse::Ok().body(""))
}
#[get("/out")]
async fn sign_out(store: Data<RedisStorage<Claims>>) -> HttpResponse {
#[post("/out")]
async fn sign_out(_store: Data<RedisStorage<Claims>>) -> HttpResponse {
HttpResponse::Ok().body("")
}

View File

@ -48,9 +48,11 @@ async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let session_config = session_config.clone();
let session_factory = session_config.factory();
App::new()
.wrap(middleware::Logger::default())
.wrap(session_config.factory())
.app_data(Data::new(session_factory.storage()))
.wrap(session_factory)
.app_data(Data::new(conn.clone()))
.app_data(Data::new(redis.clone()))
.app_data(Data::new(l10n.clone()))

View File

@ -1,6 +1,7 @@
use std::ops::Add;
use std::sync::Arc;
use actix_jwt_session::SessionStorage;
pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpResponse};
@ -13,7 +14,6 @@ use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use time::ext::*;
use time::OffsetDateTime;
mod extract_session;
@ -24,7 +24,7 @@ pub use oswilno_view::filters;
pub type UserSession = Claims;
#[derive(Clone, Copy)]
pub struct JWTTtl(time::Duration);
pub struct JWTTtl(std::time::Duration);
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
@ -73,19 +73,15 @@ pub struct LoginResponse {
claims: Claims,
}
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
#[derive(Clone)]
pub struct SessionConfigurator {
jwt_ttl: Data<JWTTtl>,
encoding_key: Data<EncodingKey>,
factory: RedisMiddlewareFactory<Claims>,
}
impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) {
config
.app_data(self.encoding_key)
.app_data(self.jwt_ttl)
.service(login)
.service(login_view)
@ -130,17 +126,16 @@ impl SessionConfigurator {
}
pub fn new(redis: redis_async_pool::RedisPool) -> Self {
let jwt_ttl = JWTTtl(31.days());
let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60));
let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
let validator = Validation::new(JWT_SIGNING_ALGO);
let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new(
Arc::new(jwt_signing_keys.encoding_key),
Arc::new(jwt_signing_keys.decoding_key),
Arc::new(validator),
Arc::new(actix_jwt_session::RedisStorage::new(redis)),
Algorithm::EdDSA,
redis,
);
Self {
encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()),
jwt_ttl: Data::new(jwt_ttl.clone()),
factory: auth_middleware_factory,
}
@ -186,18 +181,16 @@ async fn login_partial_view(t: Data<TranslationStorage>) -> SignInPartialTemplat
#[autometrics]
#[post("/login")]
async fn login(
jwt_encoding_key: Data<EncodingKey>,
jwt_ttl: Data<JWTTtl>,
db: Data<DatabaseConnection>,
redis: Data<SessionStorage<Claims>>,
payload: Form<SignInPayload>,
t: Data<oswilno_view::TranslationStorage>,
lang: Lang,
redis: Data<redis_async_pool::RedisPool>,
) -> Result<HttpResponse, Error> {
let t = t.into_inner();
let mut errors = Errors::default();
match login_inner(
jwt_encoding_key,
jwt_ttl.into_inner(),
payload.into_inner(),
db.into_inner(),
@ -221,11 +214,10 @@ async fn login(
}
async fn login_inner(
jwt_encoding_key: Data<EncodingKey>,
jwt_ttl: Arc<JWTTtl>,
payload: SignInPayload,
db: Arc<DatabaseConnection>,
redis: Arc<redis_async_pool::RedisPool>,
redis: Arc<SessionStorage<Claims>>,
errors: &mut Errors,
) -> Result<LoginResponse, SignInPayload> {
let iat = OffsetDateTime::now_utc().unix_timestamp() as usize;
@ -261,42 +253,24 @@ async fn login_inner(
audience: Audience::Web,
jwt_id: uuid::Uuid::new_v4(),
};
let jwt_token = encode(
&Header::new(JWT_SIGNING_ALGO),
&jwt_claims,
&jwt_encoding_key,
)
.map_err(|_| {
errors.push_global("Bad credentials");
payload.clone()
})?;
let Ok(bin_value) = bincode::serialize(&jwt_claims) else {
errors.push_global("Failed to sign in. Please try later");
return Err(payload.clone());
};
{
use redis::AsyncCommands;
let Ok(mut conn) = redis.get().await else {
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
};
if let Err(e) = conn
.set_ex::<'_, _, _, String>(
jwt_claims.jwt_id.as_bytes(),
bin_value,
jwt_ttl.0.as_seconds_f32() as usize,
)
.await
{
let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await {
Err(e) => {
tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
}
Ok(jwt_token) => jwt_token,
};
let bearer_token = match jwt_token.encode() {
Ok(token) => token,
Err(e) => {
tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
}
};
Ok(LoginResponse {
bearer_token: jwt_token,
bearer_token,
claims: jwt_claims,
})
}

View File

@ -8,7 +8,7 @@ pub mod lang;
#[derive(Debug, askama_actix::Template)]
#[template(path = "../templates/base.html")]
pub struct Layout<BodyTemplate: askama::DynTemplate + std::fmt::Display> {
pub struct Layout<BodyTemplate: askama::Template> {
pub main: BodyTemplate,
}

View File

@ -11,6 +11,8 @@
</head>
<body>
<base url="/" />
<main>{{ main }}</main>
<main>
{{ main|safe }}
</main>
</body>
</html>