use std::ops::Add; use std::sync::Arc; use actix_jwt_authc::*; use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::{get, post, HttpResponse}; use askama_actix::Template; use autometrics::autometrics; use futures::channel::{mpsc, mpsc::Sender}; use futures::stream::Stream; use futures::SinkExt; use jsonwebtoken::*; use oswilno_view::{Lang, TranslationStorage}; use ring::rand::SystemRandom; use ring::signature::{Ed25519KeyPair, KeyPair}; use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; use time::ext::*; use time::OffsetDateTime; use tokio::sync::Mutex; mod hashing; pub use oswilno_view::filters; pub type UserSession = Authenticated; #[derive(Clone, Copy)] pub struct JWTTtl(time::Duration); #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] enum Audience { Web, } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)] pub struct Claims { #[serde(rename = "exp")] expires_at: usize, #[serde(rename = "iat")] issues_at: usize, #[serde(rename = "sub")] subject: String, #[serde(rename = "aud")] audience: Audience, #[serde(rename = "jti")] jwt_id: uuid::Uuid, } #[derive(Serialize, Deserialize)] pub struct EmptyResponse {} #[derive(Debug, Serialize, Deserialize)] pub struct LoginResponse { bearer_token: String, claims: Claims, } const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, invalidated_jwts_store: Data, encoding_key: Data, factory: AuthenticateMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config .app_data(self.invalidated_jwts_store) .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) .service(login_view) .service(logout) .service(session_info) .service(register) .service(register_view); } pub fn factory(&self) -> AuthenticateMiddlewareFactory { self.factory.clone() } pub fn translations(&self, l10n: &mut oswilno_view::TranslationStorage) { l10n // English .with_lang(oswilno_view::Lang::En) .add("Sign in", "Sign in") .add("Sign up", "Sign up") .add("Bad credentials", "Bad credentials") .done() // Polish .with_lang(oswilno_view::Lang::Pl) .add("Sign in", "Logowanie") .add("Sign up", "Rejestracja") .add("Bad credentials", "Złe dane uwierzytelniające") .add("Login already taken", "Login jest zajęty") .add( "Login or email already taken", "Login lub adres e-mail jest zajęty", ) .add("Submit", "Wyślij") .done(); } pub fn new() -> Self { let jwt_ttl = JWTTtl(31.days()); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let validator = Validation::new(JWT_SIGNING_ALGO); let auth_middleware_settings = AuthenticateMiddlewareSettings { jwt_decoding_key: jwt_signing_keys.decoding_key, jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), jwt_validator: validator, jwt_session_key: None, }; let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(); let auth_middleware_factory = AuthenticateMiddlewareFactory::::new(stream, auth_middleware_settings.clone()); Self { invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, } } } #[derive(Template)] #[template(path = "./sign-in/full.html")] struct SignInTemplate { form: SignInPayload, lang: Lang, t: Arc, } #[derive(Template)] #[template(path = "./sign-in/partial.html")] struct SignInPartialTemplate { form: SignInPayload, lang: Lang, t: Arc, } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Default)] pub struct SignInPayload { #[serde(skip, default)] errors: Vec, login: String, #[serde(skip, default)] login_errors: Vec, password: String, #[serde(skip, default)] password_errors: Vec, } #[get("/login")] async fn login_view(t: Data) -> SignInTemplate { SignInTemplate { form: SignInPayload::default(), lang: Lang::Pl, t: t.into_inner(), } } #[autometrics] #[post("/login")] async fn login( jwt_encoding_key: Data, jwt_ttl: Data, db: Data, payload: Form, t: Data, ) -> Result { let t = t.into_inner(); match login_inner(jwt_encoding_key, jwt_ttl, db, payload).await { Ok(res) => Ok(HttpResponse::Ok().json(res)), Err(form) => Ok(HttpResponse::Ok().body( SignInPartialTemplate { form, lang: Lang::Pl, t, } .render() .unwrap(), )), } } async fn login_inner( jwt_encoding_key: Data, jwt_ttl: Data, db: Data, payload: Form, ) -> Result { let db = db.into_inner(); let mut payload = payload.into_inner(); let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0); let exp = expires_at.unix_timestamp() as usize; use sea_orm::*; let account = match oswilno_contract::accounts::Entity::find() .filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str())) .one(&*db) .await { Ok(Some(a)) => a, Ok(None) => { payload.errors.push("Bad credentials".into()); return Err(payload); } Err(e) => { tracing::warn!("Failed to find account: {e}"); payload.errors.push("Bad credentials".into()); return Err(payload); } }; if hashing::verify(account.pass_hash.as_str(), payload.password.as_str()).is_err() { payload.errors.push("Bad credentials".into()); return Err(payload); } let jwt_claims = Claims { issues_at: iat, subject: format!("account-{}", account.id), expires_at: exp, audience: Audience::Web, jwt_id: uuid::Uuid::new_v4(), }; let jwt_token = encode( &Header::new(JWT_SIGNING_ALGO), &jwt_claims, &jwt_encoding_key, ) .map_err(|_| { payload.errors.push("Bad credentials".into()); payload })?; Ok(LoginResponse { bearer_token: jwt_token, claims: jwt_claims, }) } #[autometrics] #[get("/session")] async fn session_info(authenticated: UserSession) -> Result { Ok(HttpResponse::Ok().json(authenticated)) } #[autometrics] #[get("/logout")] async fn logout( invalidated_jwts: Data, authenticated: Authenticated, ) -> Result { invalidated_jwts.add_to_invalidated(authenticated).await; Ok(HttpResponse::Ok().json(EmptyResponse {})) } #[derive(Debug, Default, Serialize, Deserialize, Clone, Eq, PartialEq)] struct AccountInfo { #[serde(skip)] errors: Vec, login: String, #[serde(skip)] login_errors: Vec, email: String, #[serde(skip)] email_errors: Vec, password: String, #[serde(skip)] password_errors: Vec, } #[derive(askama_actix::Template)] #[template(path = "./register/full.html")] struct RegisterTemplate { form: AccountInfo, t: Arc, lang: Lang, } #[get("/register")] async fn register_view(t: Data) -> RegisterTemplate { RegisterTemplate { form: AccountInfo::default(), t: t.into_inner(), lang: Lang::Pl, } } #[derive(askama_actix::Template)] #[template(path = "./register/partial.html")] struct RegisterPartialTemplate { form: AccountInfo, t: Arc, lang: Lang, } #[autometrics] #[post("/register")] async fn register( db: Data, payload: Form, t: Data, ) -> Result { let t = t.into_inner(); Ok( match register_internal(db.into_inner(), payload.into_inner(), t.clone()).await { Ok(res) => res, Err(p) => HttpResponse::BadRequest().body( RegisterPartialTemplate { form: p, t, lang: Lang::Pl, } .render() .unwrap(), ), }, ) } async fn register_internal( db: Arc, mut p: AccountInfo, t: Arc, ) -> Result { use oswilno_contract::accounts::*; use sea_orm::*; let pass = match hashing::encrypt(p.password.as_str()) { Ok(p) => p, Err(e) => { tracing::warn!("{e}"); return Ok(HttpResponse::InternalServerError().body("")); } }; match Entity::find() .filter(Column::Login.eq(&p.login)) .one(&*db) .await { Ok(None) | Err(_) => { p.login_errors.push("Login already taken".into()); return Err(p); } _ => (), }; let model = match (ActiveModel { id: NotSet, login: Set(p.login.to_string()), // email: Set(p.email.to_string()), pass_hash: Set(pass), ..Default::default() }) .save(&*db) .await { Ok(model) => model, Err(e) => { tracing::warn!("{e}"); p.login_errors .push(t.to_lang(Lang::Pl, "Login or email already taken")); return Err(p); } }; tracing::info!("{model:?}"); Ok(HttpResponse::SeeOther() .append_header(("Location", "/login")) .json(EmptyResponse {})) } #[derive(Clone)] struct InvalidatedJWTStore { // store: Arc>, tx: Arc>>, } impl InvalidatedJWTStore { /// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s fn new_with_stream() -> ( InvalidatedJWTStore, impl Stream, ) { // let invalidated = Arc::new(DashSet::new()); let (tx, rx) = mpsc::channel(100); let tx_to_hold = Arc::new(Mutex::new(tx)); ( InvalidatedJWTStore { // store: invalidated, tx: tx_to_hold, }, rx, ) } async fn add_to_invalidated(&self, authenticated: Authenticated) { // self.store.insert(authenticated.jwt.clone()); let mut tx = self.tx.lock().await; if let Err(_e) = tx .send(InvalidatedTokensEvent::Add(authenticated.jwt)) .await { #[cfg(feature = "tracing")] error!(error = ?_e, "Failed to send update on adding to invalidated") } } } struct JwtSigningKeys { encoding_key: EncodingKey, decoding_key: DecodingKey, } impl JwtSigningKeys { fn generate() -> Result> { let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?; let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?; let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref()); Ok(JwtSigningKeys { encoding_key, decoding_key, }) } }