use std::ops::Add; use std::sync::Arc; use actix_jwt_authc::*; use actix_web::web::{Data, Json, ServiceConfig}; use actix_web::{get, post, HttpResponse}; use dashmap::DashSet; use futures::channel::{mpsc, mpsc::Sender}; use futures::stream::Stream; use futures::SinkExt; use jsonwebtoken::*; use ring::rand::SystemRandom; use ring::signature::{Ed25519KeyPair, KeyPair}; use sea_orm::{DatabaseConnection, EntityTrait}; use serde::{Deserialize, Serialize}; use time::ext::*; use time::OffsetDateTime; use tokio::sync::Mutex; use uuid::Uuid; pub type UserSession = Authenticated; #[derive(Clone, Copy)] pub struct JWTTtl(time::Duration); #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] pub struct Claims { exp: usize, iat: usize, sub: String, } #[derive(Serialize, Deserialize)] pub struct EmptyResponse {} #[derive(Debug, Serialize, Deserialize)] pub struct LoginResponse { bearer_token: String, claims: Claims, } const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, invalidated_jwts_store: Data, encoding_key: Data, factory: AuthenticateMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config .app_data(self.invalidated_jwts_store) .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) .service(logout) .service(session_info); } pub fn factory(&self) -> AuthenticateMiddlewareFactory { self.factory.clone() } pub fn new() -> Self { let jwt_ttl = JWTTtl(31.days()); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let validator = Validation::new(JWT_SIGNING_ALGO); let auth_middleware_settings = AuthenticateMiddlewareSettings { jwt_decoding_key: jwt_signing_keys.decoding_key, jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), jwt_validator: validator, jwt_session_key: None, }; let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(); let auth_middleware_factory = AuthenticateMiddlewareFactory::::new(stream, auth_middleware_settings.clone()); Self { invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, } } } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] pub struct SignInPayload { login: String, password: String, } #[post("/login")] async fn login( jwt_encoding_key: Data, jwt_ttl: Data, db: Data, payload: Json, ) -> Result { let db = db.into_inner(); let payload = payload.into_inner(); let sub = format!("{}", Uuid::new_v4().as_u128()); let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0); let exp = expires_at.unix_timestamp() as usize; use sea_orm::*; let account = match oswilno_contract::accounts::Entity::find() .filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str())) .one(&*db) .await { Ok(a) => a, Err(e) => { tracing::warn!("Failed to find account: {e}"); return Err(Error::InternalError); } }; let jwt_claims = Claims { iat, exp, sub }; let jwt_token = encode( &Header::new(JWT_SIGNING_ALGO), &jwt_claims, &jwt_encoding_key, ) .map_err(|_| Error::InternalError)?; let login_response = LoginResponse { bearer_token: jwt_token, claims: jwt_claims, }; Ok(HttpResponse::Ok().json(login_response)) } #[get("/session")] async fn session_info(authenticated: UserSession) -> Result { Ok(HttpResponse::Ok().json(authenticated)) } #[get("/logout")] async fn logout( invalidated_jwts: Data, authenticated: Authenticated, ) -> Result { invalidated_jwts.add_to_invalidated(authenticated).await; Ok(HttpResponse::Ok().json(EmptyResponse {})) } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] struct AccountInfo { login: String, password: String, } #[post("/register")] async fn register( invalidated_jwts: Data, authenticated: Authenticated, db: Data, payload: Json, ) -> Result { use sea_orm::*; use oswilno_contract::accounts::*; Ok(HttpResponse::Ok().json(EmptyResponse {})) } #[derive(Clone)] struct InvalidatedJWTStore { store: Arc>, tx: Arc>>, } impl InvalidatedJWTStore { /// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s fn new_with_stream() -> ( InvalidatedJWTStore, impl Stream, ) { let invalidated = Arc::new(DashSet::new()); let (tx, rx) = mpsc::channel(100); let tx_to_hold = Arc::new(Mutex::new(tx)); ( InvalidatedJWTStore { store: invalidated, tx: tx_to_hold, }, rx, ) } async fn add_to_invalidated(&self, authenticated: Authenticated) { self.store.insert(authenticated.jwt.clone()); let mut tx = self.tx.lock().await; if let Err(_e) = tx .send(InvalidatedTokensEvent::Add(authenticated.jwt)) .await { #[cfg(feature = "tracing")] error!(error = ?_e, "Failed to send update on adding to invalidated") } } } struct JwtSigningKeys { encoding_key: EncodingKey, decoding_key: DecodingKey, } impl JwtSigningKeys { fn generate() -> Result> { let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?; let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?; let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref()); Ok(JwtSigningKeys { encoding_key, decoding_key, }) } } mod hasing { use argon2::{ password_hash::{ rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString, }, Argon2 }; pub fn encrypt(password: &str) -> argon2::password_hash::Result { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); argon2.hash_password(password.as_bytes(), &salt).map(|hash| hash.to_string()) } pub fn verify(password_hash: &str, password: &str) -> argon2::password_hash::Result<()> { let parsed_hash = PasswordHash::new(&password_hash)?; Argon2::default() .verify_password(password.as_bytes(), &parsed_hash) } #[cfg(test)] mod tests { use super::*; #[test] fn check_always_random_salt() { let pass = "ahs9dya8tsd7fa8tsa86tT&^R%^DS^%ARS&A"; let hash = encrypt(pass).unwrap(); assert!(verify( hash.as_str(), pass).is_ok()); } } }