use std::ops::Add; use std::sync::Arc; use actix_jwt_authc::*; use actix_web::web::{Data, ServiceConfig, Json}; use actix_web::{get, post, HttpResponse}; use dashmap::DashSet; use futures::channel::{mpsc, mpsc::Sender}; use futures::stream::Stream; use futures::SinkExt; use jsonwebtoken::*; use ring::rand::SystemRandom; use ring::signature::{Ed25519KeyPair, KeyPair}; use sea_orm::{DatabaseConnection, EntityTrait}; use serde::{Deserialize, Serialize}; use time::ext::*; use time::OffsetDateTime; use tokio::sync::Mutex; use uuid::Uuid; pub type UserSession = Authenticated; #[derive(Clone, Copy)] pub struct JWTTtl(time::Duration); #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] pub struct Claims { exp: usize, iat: usize, sub: String, } #[derive(Serialize, Deserialize)] pub struct EmptyResponse {} #[derive(Debug, Serialize, Deserialize)] pub struct LoginResponse { bearer_token: String, claims: Claims, } const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, invalidated_jwts_store: Data, encoding_key: Data, factory: AuthenticateMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config .app_data(self.invalidated_jwts_store) .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) .service(logout) .service(session_info); } pub fn factory(&self) -> AuthenticateMiddlewareFactory { self.factory.clone() } pub fn new() -> Self { let jwt_ttl = JWTTtl(31.days()); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let validator = Validation::new(JWT_SIGNING_ALGO); let auth_middleware_settings = AuthenticateMiddlewareSettings { jwt_decoding_key: jwt_signing_keys.decoding_key, jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), jwt_validator: validator, jwt_session_key: None, }; let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(); let auth_middleware_factory = AuthenticateMiddlewareFactory::::new(stream, auth_middleware_settings.clone()); Self { invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, } } } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] pub struct SignInPayload { login: String, password: String, } #[post("/login")] async fn login( jwt_encoding_key: Data, jwt_ttl: Data, db: Data, payload: Json, ) -> Result { let db = db.into_inner(); let payload = payload.into_inner(); let sub = format!("{}", Uuid::new_v4().as_u128()); let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0); let exp = expires_at.unix_timestamp() as usize; use sea_orm::*; let account = match oswilno_contract::accounts::Entity::find().filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str())).one(&*db).await { Ok(a) => a, Err(e) => { tracing::warn!("Failed to find account: {e}"); return Err(Error::InternalError) } }; let jwt_claims = Claims { iat, exp, sub }; let jwt_token = encode( &Header::new(JWT_SIGNING_ALGO), &jwt_claims, &jwt_encoding_key, ) .map_err(|_| Error::InternalError)?; let login_response = LoginResponse { bearer_token: jwt_token, claims: jwt_claims, }; Ok(HttpResponse::Ok().json(login_response)) } #[get("/session")] async fn session_info(authenticated: UserSession) -> Result { Ok(HttpResponse::Ok().json(authenticated)) } #[get("/logout")] async fn logout( invalidated_jwts: Data, authenticated: Authenticated, ) -> Result { invalidated_jwts.add_to_invalidated(authenticated).await; Ok(HttpResponse::Ok().json(EmptyResponse {})) } #[derive(Clone)] struct InvalidatedJWTStore { store: Arc>, tx: Arc>>, } impl InvalidatedJWTStore { /// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s fn new_with_stream() -> ( InvalidatedJWTStore, impl Stream, ) { let invalidated = Arc::new(DashSet::new()); let (tx, rx) = mpsc::channel(100); let tx_to_hold = Arc::new(Mutex::new(tx)); ( InvalidatedJWTStore { store: invalidated, tx: tx_to_hold, }, rx, ) } async fn add_to_invalidated(&self, authenticated: Authenticated) { self.store.insert(authenticated.jwt.clone()); let mut tx = self.tx.lock().await; if let Err(_e) = tx .send(InvalidatedTokensEvent::Add(authenticated.jwt)) .await { #[cfg(feature = "tracing")] error!(error = ?_e, "Failed to send update on adding to invalidated") } } } struct JwtSigningKeys { encoding_key: EncodingKey, decoding_key: DecodingKey, } impl JwtSigningKeys { fn generate() -> Result> { let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?; let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?; let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref()); Ok(JwtSigningKeys { encoding_key, decoding_key, }) } } mod hasing { pub fn hash() { use argon2::{ password_hash::{ rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString }, Argon2 }; let password = b"hunter42"; // Bad password; don't actually use! let salt = SaltString::generate(&mut OsRng); // Argon2 with default params (Argon2id v19) let argon2 = Argon2::default(); // Hash password to PHC string ($argon2id$v=19$...) let password_hash = argon2.hash_password(password, &salt)?.to_string(); // Verify password against PHC string. // // NOTE: hash params from `parsed_hash` are used instead of what is configured in the // `Argon2` instance. let parsed_hash = PasswordHash::new(&password_hash)?; assert!(Argon2::default().verify_password(password, &parsed_hash).is_ok()); } }