use std::ops::Add; use std::sync::Arc; use actix_http::body::MessageBody; use actix_jwt_authc::*; use actix_web::web::{Data, ServiceConfig}; use actix_web::{get, HttpResponse}; use dashmap::DashSet; use futures::channel::{mpsc, mpsc::Sender}; use futures::stream::Stream; use futures::SinkExt; use jsonwebtoken::*; use ring::rand::SystemRandom; use ring::signature::{Ed25519KeyPair, KeyPair}; use serde::{Deserialize, Serialize}; use time::ext::*; use time::OffsetDateTime; use tokio::sync::Mutex; use uuid::Uuid; pub type UserSession = Authenticated; #[derive(Clone, Copy)] pub struct JWTTtl(time::Duration); #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] pub struct Claims { exp: usize, iat: usize, sub: String, } #[derive(Serialize, Deserialize)] pub struct EmptyResponse {} #[derive(Debug, Serialize, Deserialize)] pub struct LoginResponse { bearer_token: String, claims: Claims, } const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, invalidated_jwts_store: Data, encoding_key: Data, factory: AuthenticateMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config .app_data(self.invalidated_jwts_store) .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) .service(logout) .service(session_info); } pub fn factory(&self) -> AuthenticateMiddlewareFactory { self.factory.clone() } pub fn new() -> Self { let jwt_ttl = JWTTtl(31.days()); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let validator = Validation::new(JWT_SIGNING_ALGO); let auth_middleware_settings = AuthenticateMiddlewareSettings { jwt_decoding_key: jwt_signing_keys.decoding_key, jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), jwt_validator: validator, jwt_session_key: None, }; let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(); let auth_middleware_factory = AuthenticateMiddlewareFactory::::new(stream, auth_middleware_settings.clone()); Self { invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, } } } #[get("/login")] async fn login( jwt_encoding_key: Data, jwt_ttl: Data, ) -> Result { let sub = format!("{}", Uuid::new_v4().as_u128()); let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0); let exp = expires_at.unix_timestamp() as usize; let jwt_claims = Claims { iat, exp, sub }; let jwt_token = encode( &Header::new(JWT_SIGNING_ALGO), &jwt_claims, &jwt_encoding_key, ) .map_err(|_| Error::InternalError)?; let login_response = LoginResponse { bearer_token: jwt_token, claims: jwt_claims, }; Ok(HttpResponse::Ok().json(login_response)) } #[get("/session")] async fn session_info(authenticated: UserSession) -> Result { Ok(HttpResponse::Ok().json(authenticated)) } #[get("/logout")] async fn logout( invalidated_jwts: Data, authenticated: Authenticated, ) -> Result { invalidated_jwts.add_to_invalidated(authenticated).await; Ok(HttpResponse::Ok().json(EmptyResponse {})) } #[derive(Clone)] struct InvalidatedJWTStore { store: Arc>, tx: Arc>>, } impl InvalidatedJWTStore { /// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s fn new_with_stream() -> ( InvalidatedJWTStore, impl Stream, ) { let invalidated = Arc::new(DashSet::new()); let (tx, rx) = mpsc::channel(100); let tx_to_hold = Arc::new(Mutex::new(tx)); ( InvalidatedJWTStore { store: invalidated, tx: tx_to_hold, }, rx, ) } async fn add_to_invalidated(&self, authenticated: Authenticated) { self.store.insert(authenticated.jwt.clone()); let mut tx = self.tx.lock().await; if let Err(_e) = tx .send(InvalidatedTokensEvent::Add(authenticated.jwt)) .await { #[cfg(feature = "tracing")] error!(error = ?_e, "Failed to send update on adding to invalidated") } } } struct JwtSigningKeys { encoding_key: EncodingKey, decoding_key: DecodingKey, } impl JwtSigningKeys { fn generate() -> Result> { let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?; let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?; let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref()); Ok(JwtSigningKeys { encoding_key, decoding_key, }) } }