Using session
This commit is contained in:
parent
5560f068b1
commit
f265d22b87
@ -1,6 +1,6 @@
|
||||
use actix_web::{dev::ServiceRequest, HttpResponse};
|
||||
use actix_web::{FromRequest, HttpMessage};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation, EncodingKey, encode, Algorithm};
|
||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -58,7 +58,11 @@ impl<T> std::ops::Deref for Authenticated<T> {
|
||||
|
||||
impl<T: Claims> Authenticated<T> {
|
||||
pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
encode(&jsonwebtoken::Header::new(self.algorithm), &*self.claims, &*self.jwt_encoding_key)
|
||||
encode(
|
||||
&jsonwebtoken::Header::new(self.algorithm),
|
||||
&*self.claims,
|
||||
&*self.jwt_encoding_key,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,7 +83,7 @@ impl<T: Claims> FromRequest for Authenticated<T> {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait(?Send)]
|
||||
pub trait TokenStorage {
|
||||
pub trait TokenStorage: Send + Sync {
|
||||
type ClaimsType: Claims;
|
||||
|
||||
async fn get_from_jti(self: Arc<Self>, jti: uuid::Uuid) -> Result<Self::ClaimsType, Error>;
|
||||
@ -89,6 +93,48 @@ pub trait TokenStorage {
|
||||
claims: Self::ClaimsType,
|
||||
exp: std::time::Duration,
|
||||
) -> Result<(), Error>;
|
||||
|
||||
fn jwt_encoding_key(&self) -> Arc<EncodingKey>;
|
||||
|
||||
fn algorithm(&self) -> Algorithm;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStorage<ClaimsType: Claims>(Arc<dyn TokenStorage<ClaimsType = ClaimsType>>);
|
||||
|
||||
impl<ClaimsType: Claims> std::ops::Deref for SessionStorage<ClaimsType> {
|
||||
type Target = Arc<dyn TokenStorage<ClaimsType = ClaimsType>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> SessionStorage<ClaimsType> {
|
||||
pub async fn set_by_jti(
|
||||
&self,
|
||||
claims: ClaimsType,
|
||||
exp: std::time::Duration,
|
||||
) -> Result<(), Error> {
|
||||
self.0.clone().set_by_jti(claims, exp).await
|
||||
}
|
||||
|
||||
pub async fn get_from_jti(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
|
||||
self.0.clone().get_from_jti(jti).await
|
||||
}
|
||||
|
||||
pub async fn store(
|
||||
&self,
|
||||
claims: ClaimsType,
|
||||
exp: std::time::Duration,
|
||||
) -> Result<Authenticated<ClaimsType>, Error> {
|
||||
self.set_by_jti(claims.clone(), exp).await?;
|
||||
Ok(Authenticated {
|
||||
claims: Arc::new(claims),
|
||||
jwt_encoding_key: self.0.jwt_encoding_key(),
|
||||
algorithm: self.algorithm(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct Extractor;
|
||||
@ -99,7 +145,7 @@ impl Extractor {
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
) -> Result<(), Error> {
|
||||
let Some(authorisation_header) = req
|
||||
.headers()
|
||||
@ -111,13 +157,17 @@ impl Extractor {
|
||||
.to_str()
|
||||
.map_err(|_| Error::NoAuthHeader)?;
|
||||
|
||||
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm))
|
||||
.map_err(|_e| {
|
||||
let decoded_claims =
|
||||
decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm)).map_err(
|
||||
|_e| {
|
||||
// let error_message = e.to_string();
|
||||
Error::InvalidSession
|
||||
})?;
|
||||
},
|
||||
)?;
|
||||
|
||||
let stored = storage
|
||||
.0
|
||||
.clone()
|
||||
.get_from_jti(decoded_claims.claims.jti())
|
||||
.await
|
||||
.map_err(|_| Error::InvalidSession)?;
|
||||
@ -126,9 +176,10 @@ impl Extractor {
|
||||
return Err(Error::InvalidSession);
|
||||
}
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(Authenticated {
|
||||
req.extensions_mut().insert(Authenticated {
|
||||
claims: Arc::new(decoded_claims.claims),
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
use super::*;
|
||||
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
|
||||
use futures_util::future::LocalBoxFuture;
|
||||
use jsonwebtoken::{DecodingKey, Validation};
|
||||
use redis::AsyncCommands;
|
||||
use std::future::{ready, Ready};
|
||||
use std::marker::PhantomData;
|
||||
@ -9,15 +8,23 @@ use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisStorage<ClaimsType: Claims> {
|
||||
struct RedisStorage<ClaimsType: Claims> {
|
||||
pool: redis_async_pool::RedisPool,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
algorithm: Algorithm,
|
||||
_claims_type_marker: PhantomData<ClaimsType>,
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> RedisStorage<ClaimsType> {
|
||||
pub fn new(pool: redis_async_pool::RedisPool) -> Self {
|
||||
pub fn new(
|
||||
pool: redis_async_pool::RedisPool,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
algorithm: Algorithm,
|
||||
) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
jwt_encoding_key,
|
||||
algorithm,
|
||||
_claims_type_marker: Default::default(),
|
||||
}
|
||||
}
|
||||
@ -53,6 +60,14 @@ where
|
||||
.map_err(|_| Error::WriteFailed)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn jwt_encoding_key(&self) -> Arc<EncodingKey> {
|
||||
self.jwt_encoding_key.clone()
|
||||
}
|
||||
|
||||
fn algorithm(&self) -> Algorithm {
|
||||
self.algorithm
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RedisMiddleware<S, ClaimsType>
|
||||
@ -64,7 +79,7 @@ where
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
}
|
||||
|
||||
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
|
||||
@ -88,7 +103,14 @@ where
|
||||
let storage = self.storage.clone();
|
||||
|
||||
async move {
|
||||
Extractor::extract_bearer_jwt(&req, jwt_encoding_key, jwt_decoding_key, algorithm, storage).await?;
|
||||
Extractor::extract_bearer_jwt(
|
||||
&req,
|
||||
jwt_encoding_key,
|
||||
jwt_decoding_key,
|
||||
algorithm,
|
||||
storage,
|
||||
)
|
||||
.await?;
|
||||
let res = svc.call(req).await?;
|
||||
Ok(res)
|
||||
}
|
||||
@ -101,7 +123,7 @@ pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
algorithm: Algorithm,
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
storage: SessionStorage<ClaimsType>,
|
||||
_claims_type_marker: PhantomData<ClaimsType>,
|
||||
}
|
||||
|
||||
@ -112,15 +134,20 @@ impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
|
||||
algorithm: Algorithm,
|
||||
pool: redis_async_pool::RedisPool,
|
||||
) -> Self {
|
||||
let storage = Arc::new(RedisStorage::new(pool, jwt_encoding_key.clone(), algorithm));
|
||||
|
||||
Self {
|
||||
jwt_encoding_key,
|
||||
jwt_decoding_key,
|
||||
algorithm,
|
||||
storage: RedisStorage::new(
|
||||
pool),
|
||||
storage: SessionStorage(storage),
|
||||
_claims_type_marker: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage(&self) -> SessionStorage<ClaimsType> {
|
||||
self.storage.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for RedisMiddlewareFactory<ClaimsType>
|
||||
@ -138,8 +165,9 @@ where
|
||||
ready(Ok(RedisMiddleware {
|
||||
service: Rc::new(service),
|
||||
storage: self.storage.clone(),
|
||||
jwt_encoding_key: self.jwt_encoding_key.clone(),
|
||||
jwt_decoding_key: self.jwt_decoding_key.clone(),
|
||||
jwt_validator: self.jwt_validator.clone(),
|
||||
algorithm: self.algorithm,
|
||||
_claims_type_marker: PhantomData,
|
||||
}))
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage};
|
||||
use actix_web::get;
|
||||
use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage, TokenStorage};
|
||||
use actix_web::http::StatusCode;
|
||||
use actix_web::web::Data;
|
||||
use actix_web::web::{Data, Json};
|
||||
use actix_web::HttpResponse;
|
||||
use actix_web::{get, post};
|
||||
use actix_web::{http::header::ContentType, test, App};
|
||||
use jsonwebtoken::*;
|
||||
use ring::rand::SystemRandom;
|
||||
@ -25,8 +25,6 @@ impl actix_jwt_session::Claims for Claims {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn not_authenticated() {
|
||||
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
|
||||
let validator = Validation::new(JWT_SIGNING_ALGO);
|
||||
let redis = {
|
||||
use redis_async_pool::{RedisConnectionManager, RedisPool};
|
||||
RedisPool::new(
|
||||
@ -41,12 +39,14 @@ async fn not_authenticated() {
|
||||
|
||||
let keys = JwtSigningKeys::generate().unwrap();
|
||||
let factory = RedisMiddlewareFactory::<Claims>::new(
|
||||
Arc::new(keys.encoding_key),
|
||||
Arc::new(keys.decoding_key),
|
||||
Arc::new(validator),
|
||||
Arc::new(RedisStorage::new(redis.clone())),
|
||||
Algorithm::EdDSA,
|
||||
redis.clone(),
|
||||
);
|
||||
|
||||
let app = App::new()
|
||||
.app_data(factory.storage())
|
||||
.wrap(factory.clone())
|
||||
.app_data(Data::new(redis.clone()))
|
||||
.service(sign_in)
|
||||
@ -56,26 +56,56 @@ async fn not_authenticated() {
|
||||
|
||||
let app = actix_web::test::init_service(app).await;
|
||||
|
||||
let res = test::call_service(&app, test::TestRequest::default()
|
||||
let res = test::call_service(
|
||||
&app,
|
||||
test::TestRequest::default()
|
||||
.insert_header(ContentType::plaintext())
|
||||
.to_request()).await;
|
||||
.to_request(),
|
||||
)
|
||||
.await;
|
||||
assert!(res.status().is_success());
|
||||
|
||||
let res = test::call_service(&app, test::TestRequest::default()
|
||||
let res = test::call_service(
|
||||
&app,
|
||||
test::TestRequest::default()
|
||||
.uri("/s")
|
||||
.insert_header(ContentType::plaintext())
|
||||
.to_request()).await;
|
||||
let s = StatusCode::UNAUTHORIZED;
|
||||
assert_eq!(res.status(), s);
|
||||
.to_request(),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
|
||||
let origina_claims = Claims { id: Uuid::new_v4() };
|
||||
let res = test::call_service(
|
||||
&app,
|
||||
test::TestRequest::default()
|
||||
.uri("/in")
|
||||
.method(actix_web::http::Method::POST)
|
||||
.insert_header(ContentType::json())
|
||||
.set_json(&origina_claims)
|
||||
.to_request(),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[get("/in")]
|
||||
async fn sign_in(store: Data<RedisStorage<Claims>>) -> HttpResponse {
|
||||
HttpResponse::Ok().body("")
|
||||
#[post("/in")]
|
||||
async fn sign_in(
|
||||
store: Data<RedisStorage<Claims>>,
|
||||
claims: Json<Claims>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let claims = claims.into_inner();
|
||||
let store = store.into_inner();
|
||||
store
|
||||
.clone()
|
||||
.set_by_jti(claims, std::time::Duration::from_secs(300))
|
||||
.await
|
||||
.unwrap();
|
||||
Ok(HttpResponse::Ok().body(""))
|
||||
}
|
||||
|
||||
#[get("/out")]
|
||||
async fn sign_out(store: Data<RedisStorage<Claims>>) -> HttpResponse {
|
||||
#[post("/out")]
|
||||
async fn sign_out(_store: Data<RedisStorage<Claims>>) -> HttpResponse {
|
||||
HttpResponse::Ok().body("")
|
||||
}
|
||||
|
||||
|
@ -48,9 +48,11 @@ async fn main() -> std::io::Result<()> {
|
||||
|
||||
HttpServer::new(move || {
|
||||
let session_config = session_config.clone();
|
||||
let session_factory = session_config.factory();
|
||||
App::new()
|
||||
.wrap(middleware::Logger::default())
|
||||
.wrap(session_config.factory())
|
||||
.app_data(Data::new(session_factory.storage()))
|
||||
.wrap(session_factory)
|
||||
.app_data(Data::new(conn.clone()))
|
||||
.app_data(Data::new(redis.clone()))
|
||||
.app_data(Data::new(l10n.clone()))
|
||||
|
@ -1,6 +1,7 @@
|
||||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_jwt_session::SessionStorage;
|
||||
pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory};
|
||||
use actix_web::web::{Data, Form, ServiceConfig};
|
||||
use actix_web::{get, post, HttpResponse};
|
||||
@ -13,7 +14,6 @@ use ring::rand::SystemRandom;
|
||||
use ring::signature::{Ed25519KeyPair, KeyPair};
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::ext::*;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
mod extract_session;
|
||||
@ -24,7 +24,7 @@ pub use oswilno_view::filters;
|
||||
pub type UserSession = Claims;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct JWTTtl(time::Duration);
|
||||
pub struct JWTTtl(std::time::Duration);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@ -73,19 +73,15 @@ pub struct LoginResponse {
|
||||
claims: Claims,
|
||||
}
|
||||
|
||||
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionConfigurator {
|
||||
jwt_ttl: Data<JWTTtl>,
|
||||
encoding_key: Data<EncodingKey>,
|
||||
factory: RedisMiddlewareFactory<Claims>,
|
||||
}
|
||||
|
||||
impl SessionConfigurator {
|
||||
pub fn app_data(self, config: &mut ServiceConfig) {
|
||||
config
|
||||
.app_data(self.encoding_key)
|
||||
.app_data(self.jwt_ttl)
|
||||
.service(login)
|
||||
.service(login_view)
|
||||
@ -130,17 +126,16 @@ impl SessionConfigurator {
|
||||
}
|
||||
|
||||
pub fn new(redis: redis_async_pool::RedisPool) -> Self {
|
||||
let jwt_ttl = JWTTtl(31.days());
|
||||
let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60));
|
||||
let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
|
||||
let validator = Validation::new(JWT_SIGNING_ALGO);
|
||||
let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new(
|
||||
Arc::new(jwt_signing_keys.encoding_key),
|
||||
Arc::new(jwt_signing_keys.decoding_key),
|
||||
Arc::new(validator),
|
||||
Arc::new(actix_jwt_session::RedisStorage::new(redis)),
|
||||
Algorithm::EdDSA,
|
||||
redis,
|
||||
);
|
||||
|
||||
Self {
|
||||
encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()),
|
||||
jwt_ttl: Data::new(jwt_ttl.clone()),
|
||||
factory: auth_middleware_factory,
|
||||
}
|
||||
@ -186,18 +181,16 @@ async fn login_partial_view(t: Data<TranslationStorage>) -> SignInPartialTemplat
|
||||
#[autometrics]
|
||||
#[post("/login")]
|
||||
async fn login(
|
||||
jwt_encoding_key: Data<EncodingKey>,
|
||||
jwt_ttl: Data<JWTTtl>,
|
||||
db: Data<DatabaseConnection>,
|
||||
redis: Data<SessionStorage<Claims>>,
|
||||
payload: Form<SignInPayload>,
|
||||
t: Data<oswilno_view::TranslationStorage>,
|
||||
lang: Lang,
|
||||
redis: Data<redis_async_pool::RedisPool>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let t = t.into_inner();
|
||||
let mut errors = Errors::default();
|
||||
match login_inner(
|
||||
jwt_encoding_key,
|
||||
jwt_ttl.into_inner(),
|
||||
payload.into_inner(),
|
||||
db.into_inner(),
|
||||
@ -221,11 +214,10 @@ async fn login(
|
||||
}
|
||||
|
||||
async fn login_inner(
|
||||
jwt_encoding_key: Data<EncodingKey>,
|
||||
jwt_ttl: Arc<JWTTtl>,
|
||||
payload: SignInPayload,
|
||||
db: Arc<DatabaseConnection>,
|
||||
redis: Arc<redis_async_pool::RedisPool>,
|
||||
redis: Arc<SessionStorage<Claims>>,
|
||||
errors: &mut Errors,
|
||||
) -> Result<LoginResponse, SignInPayload> {
|
||||
let iat = OffsetDateTime::now_utc().unix_timestamp() as usize;
|
||||
@ -261,42 +253,24 @@ async fn login_inner(
|
||||
audience: Audience::Web,
|
||||
jwt_id: uuid::Uuid::new_v4(),
|
||||
};
|
||||
let jwt_token = encode(
|
||||
&Header::new(JWT_SIGNING_ALGO),
|
||||
&jwt_claims,
|
||||
&jwt_encoding_key,
|
||||
)
|
||||
.map_err(|_| {
|
||||
errors.push_global("Bad credentials");
|
||||
payload.clone()
|
||||
})?;
|
||||
let Ok(bin_value) = bincode::serialize(&jwt_claims) else {
|
||||
errors.push_global("Failed to sign in. Please try later");
|
||||
return Err(payload.clone());
|
||||
};
|
||||
{
|
||||
use redis::AsyncCommands;
|
||||
|
||||
let Ok(mut conn) = redis.get().await else {
|
||||
errors.push_global("Failed to sign in. Please try later");
|
||||
return Err(payload);
|
||||
};
|
||||
|
||||
if let Err(e) = conn
|
||||
.set_ex::<'_, _, _, String>(
|
||||
jwt_claims.jwt_id.as_bytes(),
|
||||
bin_value,
|
||||
jwt_ttl.0.as_seconds_f32() as usize,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await {
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to set sign-in claims in redis: {e}");
|
||||
errors.push_global("Failed to sign in. Please try later");
|
||||
return Err(payload);
|
||||
}
|
||||
Ok(jwt_token) => jwt_token,
|
||||
};
|
||||
let bearer_token = match jwt_token.encode() {
|
||||
Ok(token) => token,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to set sign-in claims in redis: {e}");
|
||||
errors.push_global("Failed to sign in. Please try later");
|
||||
return Err(payload);
|
||||
}
|
||||
};
|
||||
Ok(LoginResponse {
|
||||
bearer_token: jwt_token,
|
||||
bearer_token,
|
||||
claims: jwt_claims,
|
||||
})
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ pub mod lang;
|
||||
|
||||
#[derive(Debug, askama_actix::Template)]
|
||||
#[template(path = "../templates/base.html")]
|
||||
pub struct Layout<BodyTemplate: askama::DynTemplate + std::fmt::Display> {
|
||||
pub struct Layout<BodyTemplate: askama::Template> {
|
||||
pub main: BodyTemplate,
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,8 @@
|
||||
</head>
|
||||
<body>
|
||||
<base url="/" />
|
||||
<main>{{ main }}</main>
|
||||
<main>
|
||||
{{ main|safe }}
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
|
Loading…
Reference in New Issue
Block a user