use std::ops::Add; use std::sync::Arc; use actix_jwt_authc::*; use actix_web::web::{Data, Form, ServiceConfig}; use actix_web::{get, post, HttpResponse}; use askama_actix::Template; use autometrics::autometrics; use futures::channel::{mpsc, mpsc::Sender}; use futures::stream::Stream; use futures::SinkExt; use garde::Validate; use jsonwebtoken::*; use oswilno_view::{Errors, Lang, TranslationStorage}; use ring::rand::SystemRandom; use ring::signature::{Ed25519KeyPair, KeyPair}; use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; use time::ext::*; use time::OffsetDateTime; use tokio::sync::Mutex; mod extract_session; mod hashing; pub use oswilno_view::filters; pub type UserSession = Authenticated; #[derive(Clone, Copy)] pub struct JWTTtl(time::Duration); #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)] #[serde(rename_all = "snake_case")] enum Audience { Web, } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)] #[serde(rename_all = "snake_case")] pub struct Claims { #[serde(rename = "exp")] expires_at: usize, #[serde(rename = "iat")] issues_at: usize, #[serde(rename = "sub")] subject: String, #[serde(rename = "aud")] audience: Audience, #[serde(rename = "jti")] jwt_id: uuid::Uuid, } #[derive(Serialize, Deserialize)] pub struct EmptyResponse {} #[derive(Debug, Serialize, Deserialize)] pub struct LoginResponse { bearer_token: String, claims: Claims, } const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA; #[derive(Clone)] pub struct SessionConfigurator { jwt_ttl: Data, invalidated_jwts_store: Data, encoding_key: Data, factory: AuthenticateMiddlewareFactory, } impl SessionConfigurator { pub fn app_data(self, config: &mut ServiceConfig) { config .app_data(self.invalidated_jwts_store) .app_data(self.encoding_key) .app_data(self.jwt_ttl) .service(login) .service(login_view) .service(login_partial_view) .service(logout) .service(session_info) .service(register) .service(register_view) .service(register_partial_view); } pub fn factory(&self) -> AuthenticateMiddlewareFactory { self.factory.clone() } pub fn translations(&self, l10n: &mut oswilno_view::TranslationStorage) { l10n // English .with_lang(oswilno_view::Lang::En) .add("Sign in", "Sign in") .add("Sign up", "Sign up") .add("Bad credentials", "Bad credentials") .done() // Polish .with_lang(oswilno_view::Lang::Pl) .add("Sign in", "Logowanie") .add("Sign up", "Rejestracja") .add("Bad credentials", "Złe dane uwierzytelniające") .add("is taken", "jest zajęty") .add("is not strong enough", "jest za słabe") .add( "length is lower than 8", "długość jest mniejsza niż 8 znaków", ) .add( "Login or email already taken", "Login lub adres e-mail jest zajęty", ) .add("Password", "Hasło") .add("Submit", "Wyślij") .done(); } pub fn new(redis: redis_async_pool::RedisPool) -> Self { let jwt_ttl = JWTTtl(31.days()); let jwt_signing_keys = JwtSigningKeys::generate().unwrap(); let validator = Validation::new(JWT_SIGNING_ALGO); let auth_middleware_settings = AuthenticateMiddlewareSettings { jwt_decoding_key: jwt_signing_keys.decoding_key, jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]), jwt_validator: validator, jwt_session_key: None, }; let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(redis); let auth_middleware_factory = AuthenticateMiddlewareFactory::::new(stream, auth_middleware_settings.clone()); Self { invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()), encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()), jwt_ttl: Data::new(jwt_ttl.clone()), factory: auth_middleware_factory, } } } #[derive(Template)] #[template(path = "./sign-in/full.html")] struct SignInTemplate { form: SignInPayload, lang: Lang, t: Arc, errors: Errors, } #[derive(Template)] #[template(path = "./sign-in/partial.html")] struct SignInPartialTemplate { form: SignInPayload, lang: Lang, t: Arc, errors: Errors, } #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Default)] pub struct SignInPayload { login: String, password: String, } #[get("/login")] async fn login_view(t: Data) -> SignInTemplate { SignInTemplate { form: SignInPayload::default(), lang: Lang::Pl, t: t.into_inner(), errors: Errors::default(), } } #[get("/p/login")] async fn login_partial_view(t: Data) -> SignInPartialTemplate { SignInPartialTemplate { form: SignInPayload::default(), lang: Lang::Pl, t: t.into_inner(), errors: Errors::default(), } } #[autometrics] #[post("/login")] async fn login( jwt_encoding_key: Data, jwt_ttl: Data, db: Data, payload: Form, t: Data, lang: Lang, redis: Data, ) -> Result { let t = t.into_inner(); let mut errors = Errors::default(); match login_inner( jwt_encoding_key, jwt_ttl.into_inner(), payload.into_inner(), db.into_inner(), redis.into_inner(), &mut errors, ) .await { Ok(res) => Ok(HttpResponse::Ok().json(res)), Err(form) => Ok(HttpResponse::Ok().body( (SignInPartialTemplate { form, lang, t, errors, }) .render() .unwrap(), )), } } async fn login_inner( jwt_encoding_key: Data, jwt_ttl: Arc, payload: SignInPayload, db: Arc, redis: Arc, errors: &mut Errors, ) -> Result { let iat = OffsetDateTime::now_utc().unix_timestamp() as usize; let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0); let exp = expires_at.unix_timestamp() as usize; use sea_orm::*; let account = match oswilno_contract::accounts::Entity::find() .filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str())) .one(&*db) .await { Ok(Some(a)) => a, Ok(None) => { errors.push_global("Bad credentials"); return Err(payload); } Err(e) => { tracing::warn!("Failed to find account: {e}"); errors.push_global("Bad credentials"); return Err(payload); } }; if hashing::verify(account.pass_hash.as_str(), payload.password.as_str()).is_err() { errors.push_global("Bad credentials"); return Err(payload); } let jwt_claims = Claims { issues_at: iat, subject: format!("account-{}", account.id), expires_at: exp, audience: Audience::Web, jwt_id: uuid::Uuid::new_v4(), }; let jwt_token = encode( &Header::new(JWT_SIGNING_ALGO), &jwt_claims, &jwt_encoding_key, ) .map_err(|_| { errors.push_global("Bad credentials"); payload.clone() })?; let Ok(bin_value) = bincode::serialize(&jwt_claims) else { errors.push_global("Failed to sign in. Please try later"); return Err(payload.clone()); }; { use redis::AsyncCommands; let Ok(mut conn) = redis.get().await else { errors.push_global("Failed to sign in. Please try later"); return Err(payload); }; if let Err(e) = conn .set_ex::<'_, _, _, String>( jwt_claims.jwt_id.as_bytes(), bin_value, jwt_ttl.0.as_seconds_f32() as usize, ) .await { tracing::warn!("Failed to set sign-in claims in redis: {e}"); errors.push_global("Failed to sign in. Please try later"); return Err(payload); } } Ok(LoginResponse { bearer_token: jwt_token, claims: jwt_claims, }) } #[autometrics] #[get("/session")] async fn session_info(authenticated: UserSession) -> Result { Ok(HttpResponse::Ok().json(authenticated)) } #[autometrics] #[get("/logout")] async fn logout( authenticated: Authenticated, redis: Data, ) -> Result { { use redis::AsyncCommands; let jwt_id = authenticated.claims.jwt_id; if let Ok(mut conn) = redis.get().await { if conn.del::<_, String>(jwt_id.as_bytes()).await.is_err() {} } } Ok(HttpResponse::Ok().json(EmptyResponse {})) } #[derive(Debug, Default, Serialize, Deserialize, Clone, Eq, PartialEq, garde::Validate)] #[garde(context(RegisterContext))] struct AccountInfo { #[garde(length(min = 4, max = 30), custom(is_login_free))] #[serde(rename = "login")] input_login: String, #[garde(email, custom(is_email_free))] email: String, #[garde(length(min = 8, max = 50), custom(is_strong_password))] password: String, } #[derive(askama_actix::Template)] #[template(path = "./register/full.html")] struct RegisterTemplate { form: AccountInfo, t: Arc, lang: Lang, errors: oswilno_view::Errors, } #[get("/register")] async fn register_view(t: Data) -> RegisterTemplate { RegisterTemplate { form: AccountInfo::default(), t: t.into_inner(), lang: Lang::Pl, errors: oswilno_view::Errors::default(), } } #[get("/p/register")] async fn register_partial_view(t: Data) -> RegisterTemplate { RegisterTemplate { form: AccountInfo::default(), t: t.into_inner(), lang: Lang::Pl, errors: oswilno_view::Errors::default(), } } #[derive(askama_actix::Template)] #[template(path = "./register/partial.html")] struct RegisterPartialTemplate { form: AccountInfo, t: Arc, lang: Lang, errors: oswilno_view::Errors, } #[autometrics] #[post("/register")] async fn register( db: Data, payload: Form, t: Data, lang: Lang, ) -> Result { let t = t.into_inner(); let mut errors = oswilno_view::Errors::default(); Ok( match register_internal(db.into_inner(), payload.into_inner(), &mut errors).await { Ok(res) => res, Err(p) => HttpResponse::BadRequest().body( RegisterPartialTemplate { form: p, t, lang, errors, } .render() .unwrap(), ), }, ) } struct RegisterContext { login_taken: bool, email_taken: bool, } fn is_email_free(_value: &str, context: &RegisterContext) -> garde::Result { if context.email_taken { return Err(garde::Error::new("is taken")); } Ok(()) } fn is_login_free(_value: &str, context: &RegisterContext) -> garde::Result { if context.login_taken { return Err(garde::Error::new("is taken")); } Ok(()) } static WEAK_PASS: &str = "is not strong enough"; fn is_strong_password(value: &str, _context: &RegisterContext) -> garde::Result { if !(8..50).contains(&value.len()) { return Err(garde::Error::new(WEAK_PASS)); } let mut num = false; let mut low = false; let mut up = false; let mut spec = false; for c in value.chars() { if num && low && up && spec { return Ok(()); } num = num || c.is_numeric(); low = low || c.is_lowercase(); up = up || c.is_uppercase(); spec = spec || !c.is_alphanumeric(); } return Err(garde::Error::new(WEAK_PASS)); } async fn register_internal( db: Arc, p: AccountInfo, errors: &mut oswilno_view::Errors, ) -> Result { use oswilno_contract::accounts::*; use sea_orm::*; let query_result = db .query_one(sea_orm::Statement::from_sql_and_values( sea_orm::DbBackend::Postgres, "select login = $1 as login_taken, email = $2 as email_taken from accounts", [p.input_login.clone().into(), p.email.clone().into()], )) .await .map_err(|e| { tracing::error!("{e}"); errors.push_global("Something went wrong"); p.clone() })?; let (login_taken, email_taken) = if let Some(query_result) = query_result { let Ok((login_taken, email_taken)): Result<(bool,bool), _> = query_result.try_get_many("", &["login_taken".into(), "email_taken".into()]) else { tracing::warn!("Failed to fetch fields from query result while checking if account info exists in db"); errors.push_global("Something went wrong"); return Err(p); }; (login_taken, email_taken) } else { (false, false) }; if let Err(e) = p.validate(&RegisterContext { login_taken, email_taken, }) { errors.consume_garde(e); return Err(p); } tracing::warn!("{errors:#?}"); let pass = match hashing::encrypt(p.password.as_str()) { Ok(p) => p, Err(e) => { tracing::warn!("{e}"); return Ok(HttpResponse::InternalServerError().body("")); } }; let model = (ActiveModel { id: NotSet, login: Set(p.input_login.to_string()), email: Set(p.email.to_string()), pass_hash: Set(pass), ..Default::default() }) .save(&*db) .await .map_err(|e| { tracing::warn!("{e}"); errors.push_global("Login or email already taken"); p })?; tracing::info!("{model:?}"); Ok(HttpResponse::SeeOther() .append_header(("Location", "/login")) .json(EmptyResponse {})) } #[derive(Clone)] struct InvalidatedJWTStore { // store: Arc>, redis: redis_async_pool::RedisPool, tx: Arc>>, } impl InvalidatedJWTStore { /// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s fn new_with_stream( redis: redis_async_pool::RedisPool, ) -> ( InvalidatedJWTStore, impl Stream, ) { // let invalidated = Arc::new(DashSet::new()); let (tx, rx) = mpsc::channel(100); let tx_to_hold = Arc::new(Mutex::new(tx)); ( InvalidatedJWTStore { // store: invalidated, redis, tx: tx_to_hold, }, rx, ) } async fn add_to_invalidated(&self, authenticated: Authenticated) { // self.store.insert(authenticated.jwt.clone()); let mut tx = self.tx.lock().await; if let Err(_e) = tx .send(InvalidatedTokensEvent::Add(authenticated.jwt)) .await { #[cfg(feature = "tracing")] error!(error = ?_e, "Failed to send update on adding to invalidated") } } } struct JwtSigningKeys { encoding_key: EncodingKey, decoding_key: DecodingKey, } impl JwtSigningKeys { fn generate() -> Result> { let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?; let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?; let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref()); Ok(JwtSigningKeys { encoding_key, decoding_key, }) } }