oswilno/crates/oswilno-session/src/lib.rs
2023-08-14 12:30:32 +02:00

495 lines
14 KiB
Rust

use std::ops::Add;
use std::sync::Arc;
use actix_jwt_session::SessionStorage;
pub use actix_jwt_session::{Authenticated, Error, RedisMiddlewareFactory};
use actix_web::web::{Data, Form, ServiceConfig};
use actix_web::{get, post, HttpResponse};
use askama_actix::Template;
use autometrics::autometrics;
use garde::Validate;
use jsonwebtoken::*;
use oswilno_view::{Errors, Lang, Layout, TranslationStorage};
use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
mod extract_session;
mod hashing;
pub use oswilno_view::filters;
/// Public alias for the session [`Claims`] type.
pub type UserSession = Claims;
/// Lifetime of an issued session token.
///
/// Wrapped in a newtype so it can be registered as actix application data
/// (`Data<JWTTtl>`); `login_inner` uses it both for the `exp` claim and as
/// the TTL handed to the session storage.
#[derive(Clone, Copy)]
pub struct JWTTtl(std::time::Duration);
/// Value of the JWT `aud` claim, serialized in snake_case.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
enum Audience {
/// The only audience defined here; serialized as `"web"`.
Web,
}
/// Claims stored in every session JWT created by `login_inner`.
///
/// Each field carries an explicit `#[serde(rename = …)]`, so it is
/// serialized under its short JWT claim name.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
pub struct Claims {
/// Expiry time (`exp`); `login_inner` fills it with a Unix timestamp.
#[serde(rename = "exp")]
expires_at: usize,
/// Issue time (`iat`); `login_inner` fills it with a Unix timestamp.
#[serde(rename = "iat")]
issues_at: usize,
/// Token subject (`sub`); `login_inner` writes it as `account-<id>`.
#[serde(rename = "sub")]
subject: String,
/// Intended audience (`aud`).
#[serde(rename = "aud")]
audience: Audience,
/// Unique token identifier (`jti`), exposed through
/// `actix_jwt_session::Claims::jti`.
#[serde(rename = "jti")]
jwt_id: uuid::Uuid,
}
/// Lets `actix_jwt_session` read the token identifier from our claims.
impl actix_jwt_session::Claims for Claims {
    /// Returns the `jti` claim of this token.
    fn jti(&self) -> uuid::Uuid {
        let Self { jwt_id, .. } = self;
        *jwt_id
    }
}
impl Claims {
    /// Extracts the numeric account id from the `sub` claim.
    ///
    /// The subject is expected to have the form `account-<id>`. Any other
    /// shape, or an `<id>` that does not parse as `i32`, yields `0`.
    pub fn account_id(&self) -> i32 {
        match self.subject.strip_prefix("account-") {
            Some(raw_id) => raw_id.parse::<i32>().unwrap_or_default(),
            None => i32::default(),
        }
    }
}
/// Body with no fields; sent as JSON by `logout` and by the redirect that
/// follows a successful registration.
#[derive(Serialize, Deserialize)]
pub struct EmptyResponse {}
/// JSON body returned by a successful `POST /login`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
/// Encoded JWT produced from the stored session.
bearer_token: String,
/// The claims that were stored for this session.
claims: Claims,
}
/// Holds what is needed to wire session handling into an actix app: the
/// token lifetime and the JWT middleware factory.
#[derive(Clone)]
pub struct SessionConfigurator {
/// Token lifetime, registered as application data by `app_data`.
jwt_ttl: Data<JWTTtl>,
/// Middleware factory created in `new`; handed out by `factory`.
factory: RedisMiddlewareFactory<Claims>,
}
impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) {
config
.app_data(self.jwt_ttl)
.service(login)
.service(login_view)
.service(login_partial_view)
.service(logout)
.service(session_info)
.service(register)
.service(register_view)
.service(register_partial_view);
}
pub fn factory(&self) -> RedisMiddlewareFactory<Claims> {
self.factory.clone()
}
pub fn translations(&self, l10n: &mut oswilno_view::TranslationStorage) {
l10n
// English
.with_lang(oswilno_view::Lang::En)
.add("Sign in", "Sign in")
.add("Sign up", "Sign up")
.add("Bad credentials", "Bad credentials")
.done()
// Polish
.with_lang(oswilno_view::Lang::Pl)
.add("Sign in", "Logowanie")
.add("Sign up", "Rejestracja")
.add("Bad credentials", "Złe dane uwierzytelniające")
.add("is taken", "jest zajęty")
.add("is not strong enough", "jest za słabe")
.add(
"length is lower than 8",
"długość jest mniejsza niż 8 znaków",
)
.add(
"Login or email already taken",
"Login lub adres e-mail jest zajęty",
)
.add("Password", "Hasło")
.add("Submit", "Wyślij")
.done();
}
pub fn new(redis: redis_async_pool::RedisPool) -> Self {
let jwt_ttl = JWTTtl(std::time::Duration::from_secs(31 * 60 * 60));
let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
let auth_middleware_factory = RedisMiddlewareFactory::<Claims>::new(
Arc::new(jwt_signing_keys.encoding_key),
Arc::new(jwt_signing_keys.decoding_key),
Algorithm::EdDSA,
redis,
);
Self {
jwt_ttl: Data::new(jwt_ttl.clone()),
factory: auth_middleware_factory,
}
}
}
/// Askama template for the sign-in form (`./sign-in/partial.html`).
#[derive(Template)]
#[template(path = "./sign-in/partial.html")]
struct SignInPartialTemplate {
/// Form values to render (empty for a fresh view, the submitted values
/// after a failed login).
form: SignInPayload,
/// Language passed to the template.
lang: Lang,
/// Shared translation storage.
t: Arc<TranslationStorage>,
/// Errors collected while handling the form.
errors: Errors,
}
/// Form body accepted by `POST /login`.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Default)]
pub struct SignInPayload {
/// Account login; matched against the `login` column.
login: String,
/// Plain-text password, checked against the stored hash.
password: String,
}
#[get("/login")]
async fn login_view(t: Data<TranslationStorage>) -> Layout<SignInPartialTemplate> {
oswilno_view::Layout {
main: SignInPartialTemplate {
form: SignInPayload::default(),
lang: Lang::Pl,
t: t.into_inner(),
errors: Errors::default(),
},
}
}
/// `GET /p/login` — renders only the empty sign-in form, without the layout.
#[get("/p/login")]
async fn login_partial_view(t: Data<TranslationStorage>) -> SignInPartialTemplate {
    let translations = t.into_inner();
    SignInPartialTemplate {
        errors: Errors::default(),
        t: translations,
        lang: Lang::Pl,
        form: SignInPayload::default(),
    }
}
/// `POST /login` — attempts to sign the user in.
///
/// On success the response is the JSON-encoded [`LoginResponse`]. On failure
/// the sign-in form is rendered again (still with status 200), pre-filled
/// with the submitted values and the collected errors.
///
/// # Panics
///
/// Panics if the sign-in template fails to render.
#[autometrics]
#[post("/login")]
async fn login(
    jwt_ttl: Data<JWTTtl>,
    db: Data<DatabaseConnection>,
    redis: Data<SessionStorage<Claims>>,
    payload: Form<SignInPayload>,
    t: Data<oswilno_view::TranslationStorage>,
    lang: Lang,
) -> Result<HttpResponse, Error> {
    let translations = t.into_inner();
    let mut errors = Errors::default();

    let outcome = login_inner(
        jwt_ttl.into_inner(),
        payload.into_inner(),
        db.into_inner(),
        redis.into_inner(),
        &mut errors,
    )
    .await;

    let response = match outcome {
        Ok(session) => HttpResponse::Ok().json(session),
        Err(form) => {
            let page = SignInPartialTemplate {
                form,
                lang,
                t: translations,
                errors,
            };
            HttpResponse::Ok().body(page.render().unwrap())
        }
    };
    Ok(response)
}
/// Verifies the submitted credentials and opens a new session.
///
/// Steps: look the account up by login, verify the password against the
/// stored hash, store freshly built claims in the session storage for
/// `jwt_ttl`, and encode the resulting token.
///
/// # Errors
///
/// Returns the original `payload` (so the form can be re-rendered) and pushes
/// a global message onto `errors` when the account is missing, the lookup
/// fails, the password does not verify, or the session cannot be stored or
/// encoded.
async fn login_inner(
    jwt_ttl: Arc<JWTTtl>,
    payload: SignInPayload,
    db: Arc<DatabaseConnection>,
    redis: Arc<SessionStorage<Claims>>,
    errors: &mut Errors,
) -> Result<LoginResponse, SignInPayload> {
    use sea_orm::*;

    // Read the clock once so `iat` and `exp` are derived from the same
    // instant and always differ by exactly the configured TTL.
    let issued_at = OffsetDateTime::now_utc();
    let iat = issued_at.unix_timestamp() as usize;
    let exp = issued_at.add(jwt_ttl.0).unix_timestamp() as usize;

    let account = match oswilno_contract::accounts::Entity::find()
        .filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str()))
        .one(&*db)
        .await
    {
        Ok(Some(account)) => account,
        Ok(None) => {
            errors.push_global("Bad credentials");
            return Err(payload);
        }
        Err(e) => {
            tracing::warn!("Failed to find account: {e}");
            errors.push_global("Bad credentials");
            return Err(payload);
        }
    };

    if hashing::verify(account.pass_hash.as_str(), payload.password.as_str()).is_err() {
        errors.push_global("Bad credentials");
        return Err(payload);
    }

    let jwt_claims = Claims {
        issues_at: iat,
        subject: format!("account-{}", account.id),
        expires_at: exp,
        audience: Audience::Web,
        jwt_id: uuid::Uuid::new_v4(),
    };

    let jwt_token = match redis.store(jwt_claims.clone(), jwt_ttl.0).await {
        Ok(jwt_token) => jwt_token,
        Err(e) => {
            tracing::warn!("Failed to set sign-in claims in redis: {e}");
            errors.push_global("Failed to sign in. Please try later");
            return Err(payload);
        }
    };

    let bearer_token = match jwt_token.encode() {
        Ok(token) => token,
        Err(e) => {
            // This is an encoding failure, not a Redis failure; say so in the log.
            tracing::warn!("Failed to encode sign-in token: {e}");
            errors.push_global("Failed to sign in. Please try later");
            return Err(payload);
        }
    };

    Ok(LoginResponse {
        bearer_token,
        claims: jwt_claims,
    })
}
/// `GET /session` — returns the claims of the authenticated caller as JSON.
#[autometrics]
#[get("/session")]
async fn session_info(authenticated: Authenticated<Claims>) -> Result<HttpResponse, Error> {
    let claims = &*authenticated;
    let response = HttpResponse::Ok().json(claims);
    Ok(response)
}
#[autometrics]
#[get("/logout")]
async fn logout(
authenticated: Authenticated<Claims>,
redis: Data<redis_async_pool::RedisPool>,
) -> Result<HttpResponse, Error> {
{
use redis::AsyncCommands;
let jwt_id = authenticated.jwt_id;
if let Ok(mut conn) = redis.get().await {
if conn.del::<_, String>(jwt_id.as_bytes()).await.is_err() {}
}
}
Ok(HttpResponse::Ok().json(EmptyResponse {}))
}
/// Registration form payload, validated with `garde` using a
/// `RegisterContext`.
#[derive(Debug, Default, Serialize, Deserialize, Clone, Eq, PartialEq, garde::Validate)]
#[garde(context(RegisterContext))]
struct AccountInfo {
/// Desired login (form field `login`): length 4 to 30 and not already taken.
#[garde(length(min = 4, max = 30), custom(is_login_free))]
#[serde(rename = "login")]
input_login: String,
/// E-mail address: must be a valid e-mail and not already taken.
#[garde(email, custom(is_email_free))]
email: String,
/// Plain-text password: length 8 to 50 and accepted by `is_strong_password`.
#[garde(length(min = 8, max = 50), custom(is_strong_password))]
password: String,
}
#[get("/register")]
async fn register_view(t: Data<TranslationStorage>) -> Layout<RegisterPartialTemplate> {
Layout {
main: RegisterPartialTemplate {
form: AccountInfo::default(),
t: t.into_inner(),
lang: Lang::Pl,
errors: oswilno_view::Errors::default(),
},
}
}
#[get("/p/register")]
async fn register_partial_view(t: Data<TranslationStorage>) -> RegisterPartialTemplate {
RegisterPartialTemplate {
form: AccountInfo::default(),
t: t.into_inner(),
lang: Lang::Pl,
errors: oswilno_view::Errors::default(),
}
}
/// Askama template for the registration form (`./register/partial.html`).
#[derive(askama_actix::Template)]
#[template(path = "./register/partial.html")]
struct RegisterPartialTemplate {
/// Form values to render (empty for a fresh view, the submitted values
/// after a failed registration).
form: AccountInfo,
/// Shared translation storage.
t: Arc<TranslationStorage>,
/// Language passed to the template.
lang: Lang,
/// Errors collected while handling the form.
errors: oswilno_view::Errors,
}
/// `POST /register` — creates a new account from the submitted form.
///
/// Whatever `register_internal` answers on success is returned unchanged. On
/// failure the registration form is rendered again with status 400,
/// pre-filled with the submitted values and the collected errors.
///
/// # Panics
///
/// Panics if the registration template fails to render.
#[autometrics]
#[post("/register")]
async fn register(
    db: Data<DatabaseConnection>,
    payload: Form<AccountInfo>,
    t: Data<oswilno_view::TranslationStorage>,
    lang: Lang,
) -> Result<HttpResponse, Error> {
    let translations = t.into_inner();
    let mut errors = oswilno_view::Errors::default();

    let outcome = register_internal(db.into_inner(), payload.into_inner(), &mut errors).await;

    let response = match outcome {
        Ok(response) => response,
        Err(form) => {
            let page = RegisterPartialTemplate {
                form,
                t: translations,
                lang,
                errors,
            };
            HttpResponse::BadRequest().body(page.render().unwrap())
        }
    };
    Ok(response)
}
/// Validation context for `AccountInfo`: results of the uniqueness lookup
/// done by `register_internal` before validation runs.
struct RegisterContext {
/// Read by `is_login_free`; `true` makes the login field fail validation.
login_taken: bool,
/// Read by `is_email_free`; `true` makes the e-mail field fail validation.
email_taken: bool,
}
fn is_email_free(_value: &str, context: &RegisterContext) -> garde::Result {
if context.email_taken {
return Err(garde::Error::new("is taken"));
}
Ok(())
}
fn is_login_free(_value: &str, context: &RegisterContext) -> garde::Result {
if context.login_taken {
return Err(garde::Error::new("is taken"));
}
Ok(())
}
static WEAK_PASS: &str = "is not strong enough";
fn is_strong_password(value: &str, _context: &RegisterContext) -> garde::Result {
if !(8..50).contains(&value.len()) {
return Err(garde::Error::new(WEAK_PASS));
}
let mut num = false;
let mut low = false;
let mut up = false;
let mut spec = false;
for c in value.chars() {
if num && low && up && spec {
return Ok(());
}
num = num || c.is_numeric();
low = low || c.is_lowercase();
up = up || c.is_uppercase();
spec = spec || !c.is_alphanumeric();
}
return Err(garde::Error::new(WEAK_PASS));
}
/// Validates the registration form and creates the account.
///
/// Steps: ask the database whether the login or the e-mail is already used,
/// run the `garde` validation with that information, hash the password and
/// insert the new account.
///
/// # Returns
///
/// * `Ok` with a `303 See Other` redirect to `/login` once the account is
///   saved, or with a bare `500` when the password cannot be hashed.
/// * `Err` with the submitted form (so it can be re-rendered); the reasons
///   are pushed onto `errors`.
async fn register_internal(
    db: Arc<DatabaseConnection>,
    p: AccountInfo,
    errors: &mut oswilno_view::Errors,
) -> Result<HttpResponse, AccountInfo> {
    use oswilno_contract::accounts::*;
    use sea_orm::*;

    // `EXISTS` checks the whole table and always yields exactly one row.
    // Selecting `login = $1 ... from accounts` would only compare against
    // whichever single row `query_one` happens to return.
    let query_result = db
        .query_one(sea_orm::Statement::from_sql_and_values(
            sea_orm::DbBackend::Postgres,
            "select exists(select 1 from accounts where login = $1) as login_taken, \
             exists(select 1 from accounts where email = $2) as email_taken",
            [p.input_login.clone().into(), p.email.clone().into()],
        ))
        .await
        .map_err(|e| {
            tracing::error!("{e}");
            errors.push_global("Something went wrong");
            p.clone()
        })?;

    let (login_taken, email_taken) = if let Some(query_result) = query_result {
        let Ok((login_taken, email_taken)): Result<(bool, bool), _> =
            query_result.try_get_many("", &["login_taken".into(), "email_taken".into()])
        else {
            tracing::warn!("Failed to fetch fields from query result while checking if account info exists in db");
            errors.push_global("Something went wrong");
            return Err(p);
        };
        (login_taken, email_taken)
    } else {
        (false, false)
    };

    if let Err(e) = p.validate(&RegisterContext {
        login_taken,
        email_taken,
    }) {
        errors.consume_garde(e);
        return Err(p);
    }
    // Diagnostic only: validation has passed at this point, so this does not
    // belong at `warn` level.
    tracing::debug!("{errors:#?}");

    let pass = match hashing::encrypt(p.password.as_str()) {
        Ok(hash) => hash,
        Err(e) => {
            tracing::warn!("{e}");
            return Ok(HttpResponse::InternalServerError().body(""));
        }
    };

    let model = (ActiveModel {
        id: NotSet,
        login: Set(p.input_login.clone()),
        email: Set(p.email.clone()),
        pass_hash: Set(pass),
        ..Default::default()
    })
    .save(&*db)
    .await
    .map_err(|e| {
        // The pre-check above can race with a concurrent registration; a
        // failed insert is therefore reported as a taken login/e-mail.
        tracing::warn!("{e}");
        errors.push_global("Login or email already taken");
        p
    })?;
    // Log only the id: the full model also carries the password hash.
    tracing::info!("Created account: {:?}", model.id);

    Ok(HttpResponse::SeeOther()
        .append_header(("Location", "/login"))
        .json(EmptyResponse {}))
}
/// Key pair used for session tokens, in the form `jsonwebtoken` expects.
pub struct JwtSigningKeys {
/// Key used to encode (sign) tokens.
encoding_key: EncodingKey,
/// Key used to decode (verify) tokens.
decoding_key: DecodingKey,
}
impl JwtSigningKeys {
    /// Generates a fresh Ed25519 key pair and wraps both halves as
    /// `jsonwebtoken` keys.
    ///
    /// # Errors
    ///
    /// Fails when the PKCS#8 document cannot be generated or parsed back
    /// into a key pair.
    fn generate() -> Result<Self, Box<dyn std::error::Error>> {
        let rng = SystemRandom::new();
        let pkcs8 = Ed25519KeyPair::generate_pkcs8(&rng)?;
        let pair = Ed25519KeyPair::from_pkcs8(pkcs8.as_ref())?;

        Ok(Self {
            encoding_key: EncodingKey::from_ed_der(pkcs8.as_ref()),
            decoding_key: DecodingKey::from_ed_der(pair.public_key().as_ref()),
        })
    }
}