Compare commits
3 Commits
8526a45e13
...
5560f068b1
Author | SHA1 | Date | |
---|---|---|---|
5560f068b1 | |||
e34a306668 | |||
932665a767 |
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -186,9 +186,11 @@ dependencies = [
|
||||
"futures",
|
||||
"futures-lite",
|
||||
"futures-util",
|
||||
"garde",
|
||||
"jsonwebtoken",
|
||||
"redis",
|
||||
"redis-async-pool",
|
||||
"ring",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio 1.30.0",
|
||||
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||
[features]
|
||||
default = ['use-redis']
|
||||
use-redis = ["redis", "redis-async-pool"]
|
||||
serde-transparent = []
|
||||
|
||||
[dependencies]
|
||||
actix-web = "4"
|
||||
@ -21,3 +22,11 @@ serde = { version = "1.0.183", features = ["derive"] }
|
||||
thiserror = "1.0.44"
|
||||
tokio = { version = "1.30.0", features = ["full"] }
|
||||
uuid = { version = "1.4.1", features = ["v4"] }
|
||||
|
||||
[[test]]
|
||||
name = "ensure_redis_flow"
|
||||
path = "./tests/ensure_redis_flow.rs"
|
||||
|
||||
[dev-dependencies]
|
||||
garde = "0.14.0"
|
||||
ring = "0.16.20"
|
||||
|
@ -1,6 +1,6 @@
|
||||
use actix_web::{HttpMessage, FromRequest};
|
||||
use actix_web::{dev::ServiceRequest, HttpResponse};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use actix_web::{FromRequest, HttpMessage};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation, EncodingKey, encode, Algorithm};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -40,13 +40,25 @@ impl actix_web::ResponseError for Error {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Authenticated<T>(Arc<T>);
|
||||
#[cfg_attr(feature = "serde-transparent", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "serde-transparent", serde(transparent))]
|
||||
pub struct Authenticated<T> {
|
||||
claims: Arc<T>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
algorithm: Algorithm,
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for Authenticated<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&*self.0
|
||||
&*self.claims
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Claims> Authenticated<T> {
|
||||
pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
encode(&jsonwebtoken::Header::new(self.algorithm), &*self.claims, &*self.jwt_encoding_key)
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,7 +70,10 @@ impl<T: Claims> FromRequest for Authenticated<T> {
|
||||
req: &actix_web::HttpRequest,
|
||||
_payload: &mut actix_web::dev::Payload,
|
||||
) -> Self::Future {
|
||||
let value = req.extensions_mut().get::<Authenticated<T>>().map(Clone::clone);
|
||||
let value = req
|
||||
.extensions_mut()
|
||||
.get::<Authenticated<T>>()
|
||||
.map(Clone::clone);
|
||||
std::future::ready(value.ok_or_else(|| Error::NotFound.into()))
|
||||
}
|
||||
}
|
||||
@ -81,19 +96,22 @@ struct Extractor;
|
||||
impl Extractor {
|
||||
async fn extract_bearer_jwt<ClaimsType: Claims>(
|
||||
req: &ServiceRequest,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
jwt_validator: Arc<Validation>,
|
||||
algorithm: Algorithm,
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
) -> Result<(), Error> {
|
||||
let authorisation_header = req
|
||||
let Some(authorisation_header) = req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.ok_or(Error::NoAuthHeader)?;
|
||||
else {
|
||||
return Ok(())
|
||||
};
|
||||
let as_str = authorisation_header
|
||||
.to_str()
|
||||
.map_err(|_| Error::NoAuthHeader)?;
|
||||
|
||||
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &*jwt_validator)
|
||||
let decoded_claims = decode::<ClaimsType>(as_str, &*jwt_decoding_key, &Validation::new(algorithm))
|
||||
.map_err(|_e| {
|
||||
// let error_message = e.to_string();
|
||||
Error::InvalidSession
|
||||
@ -109,7 +127,9 @@ impl Extractor {
|
||||
}
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(Authenticated(Arc::new(decoded_claims.claims)));
|
||||
.insert(Authenticated {
|
||||
claims: Arc::new(decoded_claims.claims),
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -61,9 +61,10 @@ where
|
||||
{
|
||||
_claims_type_marker: std::marker::PhantomData<ClaimsType>,
|
||||
service: Rc<S>,
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
jwt_validator: Arc<Validation>,
|
||||
storage: Arc<RedisStorage<ClaimsType>>,
|
||||
algorithm: Algorithm,
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
}
|
||||
|
||||
impl<S, B, ClaimsType> Service<ServiceRequest> for RedisMiddleware<S, ClaimsType>
|
||||
@ -82,11 +83,12 @@ where
|
||||
|
||||
let svc = self.service.clone();
|
||||
let jwt_decoding_key = self.jwt_decoding_key.clone();
|
||||
let validation = self.jwt_validator.clone();
|
||||
let jwt_encoding_key = self.jwt_encoding_key.clone();
|
||||
let algorithm = self.algorithm;
|
||||
let storage = self.storage.clone();
|
||||
|
||||
async move {
|
||||
Extractor::extract_bearer_jwt(&req, jwt_decoding_key, validation, storage).await?;
|
||||
Extractor::extract_bearer_jwt(&req, jwt_encoding_key, jwt_decoding_key, algorithm, storage).await?;
|
||||
let res = svc.call(req).await?;
|
||||
Ok(res)
|
||||
}
|
||||
@ -96,22 +98,26 @@ where
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisMiddlewareFactory<ClaimsType: Claims> {
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
jwt_validator: Arc<Validation>,
|
||||
storage: Arc<RedisStorage<ClaimsType>>,
|
||||
algorithm: Algorithm,
|
||||
storage: Arc<dyn TokenStorage<ClaimsType = ClaimsType>>,
|
||||
_claims_type_marker: PhantomData<ClaimsType>,
|
||||
}
|
||||
|
||||
impl<ClaimsType: Claims> RedisMiddlewareFactory<ClaimsType> {
|
||||
pub fn new(
|
||||
jwt_encoding_key: Arc<EncodingKey>,
|
||||
jwt_decoding_key: Arc<DecodingKey>,
|
||||
jwt_validator: Arc<Validation>,
|
||||
storage: Arc<RedisStorage<ClaimsType>>,
|
||||
algorithm: Algorithm,
|
||||
pool: redis_async_pool::RedisPool,
|
||||
) -> Self {
|
||||
Self {
|
||||
jwt_encoding_key,
|
||||
jwt_decoding_key,
|
||||
jwt_validator,
|
||||
storage,
|
||||
algorithm,
|
||||
storage: RedisStorage::new(
|
||||
pool),
|
||||
_claims_type_marker: Default::default(),
|
||||
}
|
||||
}
|
||||
|
108
crates/actix-jwt-session/tests/ensure_redis_flow.rs
Normal file
108
crates/actix-jwt-session/tests/ensure_redis_flow.rs
Normal file
@ -0,0 +1,108 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_jwt_session::{Authenticated, RedisMiddlewareFactory, RedisStorage};
|
||||
use actix_web::get;
|
||||
use actix_web::http::StatusCode;
|
||||
use actix_web::web::Data;
|
||||
use actix_web::HttpResponse;
|
||||
use actix_web::{http::header::ContentType, test, App};
|
||||
use jsonwebtoken::*;
|
||||
use ring::rand::SystemRandom;
|
||||
use ring::signature::{Ed25519KeyPair, KeyPair};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
struct Claims {
|
||||
id: Uuid,
|
||||
}
|
||||
|
||||
impl actix_jwt_session::Claims for Claims {
|
||||
fn jti(&self) -> Uuid {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn not_authenticated() {
|
||||
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
|
||||
let validator = Validation::new(JWT_SIGNING_ALGO);
|
||||
let redis = {
|
||||
use redis_async_pool::{RedisConnectionManager, RedisPool};
|
||||
RedisPool::new(
|
||||
RedisConnectionManager::new(
|
||||
redis::Client::open("redis://localhost:6379").expect("Fail to connect to redis"),
|
||||
true,
|
||||
None,
|
||||
),
|
||||
5,
|
||||
)
|
||||
};
|
||||
|
||||
let keys = JwtSigningKeys::generate().unwrap();
|
||||
let factory = RedisMiddlewareFactory::<Claims>::new(
|
||||
Arc::new(keys.decoding_key),
|
||||
Arc::new(validator),
|
||||
Arc::new(RedisStorage::new(redis.clone())),
|
||||
);
|
||||
|
||||
let app = App::new()
|
||||
.wrap(factory.clone())
|
||||
.app_data(Data::new(redis.clone()))
|
||||
.service(sign_in)
|
||||
.service(sign_out)
|
||||
.service(session)
|
||||
.service(root);
|
||||
|
||||
let app = actix_web::test::init_service(app).await;
|
||||
|
||||
let res = test::call_service(&app, test::TestRequest::default()
|
||||
.insert_header(ContentType::plaintext())
|
||||
.to_request()).await;
|
||||
assert!(res.status().is_success());
|
||||
|
||||
let res = test::call_service(&app, test::TestRequest::default()
|
||||
.uri("/s")
|
||||
.insert_header(ContentType::plaintext())
|
||||
.to_request()).await;
|
||||
let s = StatusCode::UNAUTHORIZED;
|
||||
assert_eq!(res.status(), s);
|
||||
}
|
||||
|
||||
#[get("/in")]
|
||||
async fn sign_in(store: Data<RedisStorage<Claims>>) -> HttpResponse {
|
||||
HttpResponse::Ok().body("")
|
||||
}
|
||||
|
||||
#[get("/out")]
|
||||
async fn sign_out(store: Data<RedisStorage<Claims>>) -> HttpResponse {
|
||||
HttpResponse::Ok().body("")
|
||||
}
|
||||
|
||||
#[get("/s")]
|
||||
async fn session(auth: Authenticated<Claims>) -> HttpResponse {
|
||||
HttpResponse::Ok().json(&*auth)
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
async fn root() -> HttpResponse {
|
||||
HttpResponse::Ok().body("")
|
||||
}
|
||||
|
||||
pub struct JwtSigningKeys {
|
||||
encoding_key: EncodingKey,
|
||||
decoding_key: DecodingKey,
|
||||
}
|
||||
|
||||
impl JwtSigningKeys {
|
||||
fn generate() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?;
|
||||
let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?;
|
||||
let encoding_key = EncodingKey::from_ed_der(doc.as_ref());
|
||||
let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref());
|
||||
Ok(JwtSigningKeys {
|
||||
encoding_key,
|
||||
decoding_key,
|
||||
})
|
||||
}
|
||||
}
|
@ -9,7 +9,7 @@ use sea_orm::ActiveValue::{NotSet, Set};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use oswilno_session::{Claims, Authenticated};
|
||||
use oswilno_session::{Authenticated, Claims};
|
||||
use oswilno_view::Layout;
|
||||
|
||||
pub fn mount(config: &mut ServiceConfig) {
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use actix_jwt_session::{Error, RedisMiddlewareFactory, Authenticated};
|
||||
pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory};
|
||||
use actix_web::web::{Data, Form, ServiceConfig};
|
||||
use actix_web::{get, post, HttpResponse};
|
||||
use askama_actix::Template;
|
||||
|
Loading…
Reference in New Issue
Block a user