oswilno/crates/oswilno-session/src/lib.rs

551 lines
16 KiB
Rust
Raw Normal View History

2023-08-01 16:29:03 +02:00
use std::ops::Add;
use std::sync::Arc;
use actix_jwt_authc::*;
2023-08-04 16:32:10 +02:00
use actix_web::web::{Data, Form, ServiceConfig};
2023-08-01 22:38:56 +02:00
use actix_web::{get, post, HttpResponse};
2023-08-09 16:35:37 +02:00
use askama_actix::{Template, TemplateToResponse as _};
2023-08-03 16:16:46 +02:00
use autometrics::autometrics;
2023-08-01 22:06:04 +02:00
use futures::channel::{mpsc, mpsc::Sender};
2023-08-01 16:29:03 +02:00
use futures::stream::Stream;
2023-08-01 22:06:04 +02:00
use futures::SinkExt;
2023-08-05 22:20:23 +02:00
use garde::Validate;
2023-08-01 16:29:03 +02:00
use jsonwebtoken::*;
2023-08-09 15:42:29 +02:00
use oswilno_view::{Errors, Lang, TranslationStorage};
2023-08-01 16:29:03 +02:00
use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
2023-08-02 12:45:29 +02:00
use sea_orm::DatabaseConnection;
2023-08-01 16:29:03 +02:00
use serde::{Deserialize, Serialize};
use time::ext::*;
use time::OffsetDateTime;
use tokio::sync::Mutex;
2023-08-02 12:45:29 +02:00
mod hashing;
2023-08-01 16:29:03 +02:00
2023-08-04 22:39:04 +02:00
pub use oswilno_view::filters;
2023-08-01 22:06:04 +02:00
/// Authenticated request guard specialised to this crate's [`Claims`];
/// taken as a handler argument by routes that require a valid JWT.
pub type UserSession = Authenticated<Claims>;
/// Lifetime of an issued JWT. The login flow adds it to the current time to
/// compute the token's `exp` claim.
#[derive(Clone, Copy)]
pub struct JWTTtl(time::Duration);
2023-08-06 22:14:03 +02:00
/// Value carried in the JWT `aud` claim (see `Claims::audience`).
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
2023-08-02 12:45:29 +02:00
enum Audience {
/// The only audience currently defined; set on every token issued by the login flow.
Web,
}
2023-08-06 22:14:03 +02:00
/// Claims embedded in every JWT issued by this crate.
///
/// Field names are mapped to the short registered JWT claim names via
/// `serde(rename)`.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Hash)]
2023-08-01 22:06:04 +02:00
pub struct Claims {
2023-08-02 12:45:29 +02:00
/// Expiry as a Unix timestamp in seconds (`exp`).
#[serde(rename = "exp")]
expires_at: usize,
/// Issue time as a Unix timestamp in seconds (`iat`).
#[serde(rename = "iat")]
issues_at: usize,
/// Token subject (`sub`); the login flow sets it to `account-<id>`.
#[serde(rename = "sub")]
subject: String,
/// Intended audience (`aud`).
#[serde(rename = "aud")]
audience: Audience,
/// Unique token id (`jti`); the login flow also uses its bytes as the redis key
/// under which the serialized claims are stored.
#[serde(rename = "jti")]
jwt_id: uuid::Uuid,
2023-08-01 16:29:03 +02:00
}
2023-08-01 22:06:04 +02:00
/// Empty JSON object (`{}`) used as the body of responses that carry no data
/// (logout, successful registration redirect).
#[derive(Serialize, Deserialize)]
pub struct EmptyResponse {}
/// JSON body returned by a successful `POST /login`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
/// The signed JWT the client should present as a `Bearer` token.
bearer_token: String,
/// The claims that were encoded into `bearer_token`.
claims: Claims,
2023-08-01 16:29:03 +02:00
}
/// Signature algorithm used both when signing tokens at login and when
/// building the middleware's `Validation`.
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
2023-08-01 22:06:04 +02:00
/// Holds the shared state and the authentication middleware factory needed to
/// mount the session routes of this crate on an actix-web application.
#[derive(Clone)]
pub struct SessionConfigurator {
/// Token lifetime handed to the login handler as app data.
jwt_ttl: Data<JWTTtl>,
/// Store that publishes invalidated tokens; used by the logout handler.
invalidated_jwts_store: Data<InvalidatedJWTStore>,
/// Private key used to sign issued tokens.
encoding_key: Data<EncodingKey>,
/// Factory for the JWT authentication middleware (see `factory()`).
factory: AuthenticateMiddlewareFactory<Claims>,
}
impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) {
config
.app_data(self.invalidated_jwts_store)
.app_data(self.encoding_key)
.app_data(self.jwt_ttl)
.service(login)
2023-08-03 16:16:46 +02:00
.service(login_view)
2023-08-05 14:48:13 +02:00
.service(login_partial_view)
2023-08-01 22:06:04 +02:00
.service(logout)
2023-08-04 22:39:04 +02:00
.service(session_info)
.service(register)
2023-08-05 14:48:13 +02:00
.service(register_view)
.service(register_partial_view);
2023-08-01 22:06:04 +02:00
}
2023-08-02 08:56:53 +02:00
pub fn factory(&self) -> AuthenticateMiddlewareFactory<Claims> {
self.factory.clone()
}
2023-08-01 22:06:04 +02:00
2023-08-02 16:37:03 +02:00
pub fn translations(&self, l10n: &mut oswilno_view::TranslationStorage) {
2023-08-03 16:16:46 +02:00
l10n
// English
2023-08-04 22:39:04 +02:00
.with_lang(oswilno_view::Lang::En)
2023-08-02 16:37:03 +02:00
.add("Sign in", "Sign in")
.add("Sign up", "Sign up")
2023-08-03 16:16:46 +02:00
.add("Bad credentials", "Bad credentials")
2023-08-02 16:37:03 +02:00
.done()
2023-08-03 16:16:46 +02:00
// Polish
2023-08-04 22:39:04 +02:00
.with_lang(oswilno_view::Lang::Pl)
2023-08-02 16:37:03 +02:00
.add("Sign in", "Logowanie")
.add("Sign up", "Rejestracja")
2023-08-03 16:16:46 +02:00
.add("Bad credentials", "Złe dane uwierzytelniające")
2023-08-05 22:20:23 +02:00
.add("is taken", "jest zajęty")
.add("is not strong enough", "jest za słabe")
.add(
"length is lower than 8",
"długość jest mniejsza niż 8 znaków",
)
2023-08-04 22:39:04 +02:00
.add(
"Login or email already taken",
"Login lub adres e-mail jest zajęty",
)
2023-08-05 22:20:23 +02:00
.add("Password", "Hasło")
2023-08-04 22:39:04 +02:00
.add("Submit", "Wyślij")
2023-08-02 16:37:03 +02:00
.done();
}
2023-08-09 15:42:29 +02:00
pub fn new(redis: redis_async_pool::RedisPool) -> Self {
2023-08-01 22:06:04 +02:00
let jwt_ttl = JWTTtl(31.days());
let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
let validator = Validation::new(JWT_SIGNING_ALGO);
let auth_middleware_settings = AuthenticateMiddlewareSettings {
jwt_decoding_key: jwt_signing_keys.decoding_key,
jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]),
jwt_validator: validator,
jwt_session_key: None,
};
2023-08-09 15:42:29 +02:00
let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream(redis);
2023-08-01 22:06:04 +02:00
let auth_middleware_factory =
AuthenticateMiddlewareFactory::<Claims>::new(stream, auth_middleware_settings.clone());
Self {
invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()),
encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()),
jwt_ttl: Data::new(jwt_ttl.clone()),
factory: auth_middleware_factory,
}
}
2023-08-01 16:29:03 +02:00
}
2023-08-03 16:16:46 +02:00
/// Full-page sign-in view, rendered from `./sign-in/full.html`.
#[derive(Template)]
#[template(path = "./sign-in/full.html")]
struct SignInTemplate {
/// Form values to pre-fill.
form: SignInPayload,
2023-08-04 22:39:04 +02:00
/// Language the page is rendered in.
lang: Lang,
/// Translation storage shared with the template.
t: Arc<TranslationStorage>,
2023-08-09 15:42:29 +02:00
/// Errors to show to the user.
errors: Errors,
2023-08-03 16:16:46 +02:00
}
/// Sign-in form fragment, rendered from `./sign-in/partial.html`; also used to
/// re-render the form when `POST /login` fails.
#[derive(Template)]
#[template(path = "./sign-in/partial.html")]
struct SignInPartialTemplate {
/// Form values to pre-fill (the rejected submission on failure).
form: SignInPayload,
2023-08-04 22:39:04 +02:00
/// Language the fragment is rendered in.
lang: Lang,
/// Translation storage shared with the template.
t: Arc<TranslationStorage>,
2023-08-09 15:42:29 +02:00
/// Errors to show to the user.
errors: Errors,
2023-08-03 16:16:46 +02:00
}
/// Form data submitted to `POST /login`.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Default)]
2023-08-01 22:38:56 +02:00
pub struct SignInPayload {
/// Account login looked up in the accounts table.
login: String,
/// Plain-text password, verified against the stored hash.
password: String,
2023-08-03 16:16:46 +02:00
}
#[get("/login")]
2023-08-04 22:39:04 +02:00
async fn login_view(t: Data<TranslationStorage>) -> SignInTemplate {
2023-08-04 16:32:10 +02:00
SignInTemplate {
form: SignInPayload::default(),
2023-08-04 22:39:04 +02:00
lang: Lang::Pl,
t: t.into_inner(),
2023-08-09 15:42:29 +02:00
errors: Errors::default(),
2023-08-04 16:32:10 +02:00
}
2023-08-01 22:38:56 +02:00
}
2023-08-05 14:48:13 +02:00
#[get("/p/login")]
async fn login_partial_view(t: Data<TranslationStorage>) -> SignInPartialTemplate {
SignInPartialTemplate {
form: SignInPayload::default(),
lang: Lang::Pl,
t: t.into_inner(),
2023-08-09 15:42:29 +02:00
errors: Errors::default(),
2023-08-05 14:48:13 +02:00
}
}
2023-08-01 22:38:56 +02:00
2023-08-03 16:16:46 +02:00
#[autometrics]
2023-08-01 22:38:56 +02:00
#[post("/login")]
2023-08-01 16:29:03 +02:00
async fn login(
jwt_encoding_key: Data<EncodingKey>,
2023-08-01 22:06:04 +02:00
jwt_ttl: Data<JWTTtl>,
2023-08-01 22:38:56 +02:00
db: Data<DatabaseConnection>,
2023-08-03 16:16:46 +02:00
payload: Form<SignInPayload>,
t: Data<oswilno_view::TranslationStorage>,
2023-08-06 22:14:03 +02:00
lang: Lang,
2023-08-09 15:42:29 +02:00
redis: Data<redis_async_pool::RedisPool>,
2023-08-01 16:29:03 +02:00
) -> Result<HttpResponse, Error> {
2023-08-04 22:39:04 +02:00
let t = t.into_inner();
2023-08-09 15:42:29 +02:00
let mut errors = Errors::default();
match login_inner(
jwt_encoding_key,
jwt_ttl,
payload.into_inner(),
db.into_inner(),
redis.into_inner(),
&mut errors,
)
.await
{
2023-08-03 16:16:46 +02:00
Ok(res) => Ok(HttpResponse::Ok().json(res)),
2023-08-09 15:42:29 +02:00
Err(form) => Ok(HttpResponse::Ok().body(
2023-08-09 16:35:37 +02:00
(SignInPartialTemplate { form, lang, t, errors })
2023-08-09 15:42:29 +02:00
.render()
.unwrap(),
)),
2023-08-03 16:16:46 +02:00
}
}
async fn login_inner(
jwt_encoding_key: Data<EncodingKey>,
jwt_ttl: Data<JWTTtl>,
2023-08-09 15:42:29 +02:00
payload: SignInPayload,
db: Arc<DatabaseConnection>,
redis: Arc<redis_async_pool::RedisPool>,
errors: &mut Errors,
2023-08-03 16:16:46 +02:00
) -> Result<LoginResponse, SignInPayload> {
2023-08-01 16:29:03 +02:00
let iat = OffsetDateTime::now_utc().unix_timestamp() as usize;
let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0);
let exp = expires_at.unix_timestamp() as usize;
2023-08-01 22:38:56 +02:00
use sea_orm::*;
2023-08-02 08:56:53 +02:00
let account = match oswilno_contract::accounts::Entity::find()
.filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str()))
.one(&*db)
.await
{
2023-08-02 12:45:29 +02:00
Ok(Some(a)) => a,
Ok(None) => {
2023-08-09 15:42:29 +02:00
errors.push_global("Bad credentials");
2023-08-03 16:16:46 +02:00
return Err(payload);
2023-08-02 12:45:29 +02:00
}
2023-08-01 22:38:56 +02:00
Err(e) => {
tracing::warn!("Failed to find account: {e}");
2023-08-09 15:42:29 +02:00
errors.push_global("Bad credentials");
2023-08-03 16:16:46 +02:00
return Err(payload);
2023-08-01 22:38:56 +02:00
}
};
2023-08-02 12:45:29 +02:00
if hashing::verify(account.pass_hash.as_str(), payload.password.as_str()).is_err() {
2023-08-09 15:42:29 +02:00
errors.push_global("Bad credentials");
2023-08-03 16:16:46 +02:00
return Err(payload);
2023-08-02 12:45:29 +02:00
}
2023-08-01 22:38:56 +02:00
2023-08-02 12:45:29 +02:00
let jwt_claims = Claims {
issues_at: iat,
subject: format!("account-{}", account.id),
expires_at: exp,
audience: Audience::Web,
jwt_id: uuid::Uuid::new_v4(),
};
2023-08-01 16:29:03 +02:00
let jwt_token = encode(
&Header::new(JWT_SIGNING_ALGO),
&jwt_claims,
&jwt_encoding_key,
)
2023-08-03 16:16:46 +02:00
.map_err(|_| {
2023-08-09 15:42:29 +02:00
errors.push_global("Bad credentials");
payload.clone()
2023-08-03 16:16:46 +02:00
})?;
2023-08-09 15:42:29 +02:00
let Ok(bin_value) = bincode::serialize(&jwt_claims) else {
errors.push_global("Failed to sign in. Please try later");
return Err(payload.clone());
};
{
use redis::AsyncCommands;
let Ok(mut conn) = redis.get().await else {
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
};
if let Err(e) = conn
.set::<'_, _, _, String>(jwt_claims.jwt_id.as_bytes(), bin_value)
.await
{
tracing::warn!("Failed to set sign-in claims in redis: {e}");
errors.push_global("Failed to sign in. Please try later");
return Err(payload);
}
}
2023-08-03 16:16:46 +02:00
Ok(LoginResponse {
2023-08-01 16:29:03 +02:00
bearer_token: jwt_token,
claims: jwt_claims,
2023-08-03 16:16:46 +02:00
})
2023-08-01 16:29:03 +02:00
}
2023-08-03 16:16:46 +02:00
#[autometrics]
2023-08-01 16:29:03 +02:00
#[get("/session")]
2023-08-01 22:06:04 +02:00
async fn session_info(authenticated: UserSession) -> Result<HttpResponse, Error> {
2023-08-01 16:29:03 +02:00
Ok(HttpResponse::Ok().json(authenticated))
}
2023-08-03 16:16:46 +02:00
#[autometrics]
2023-08-01 16:29:03 +02:00
#[get("/logout")]
async fn logout(
invalidated_jwts: Data<InvalidatedJWTStore>,
2023-08-01 22:06:04 +02:00
authenticated: Authenticated<Claims>,
2023-08-01 16:29:03 +02:00
) -> Result<HttpResponse, Error> {
invalidated_jwts.add_to_invalidated(authenticated).await;
Ok(HttpResponse::Ok().json(EmptyResponse {}))
}
2023-08-01 22:06:04 +02:00
2023-08-05 22:20:23 +02:00
/// Registration form data submitted to `POST /register`.
///
/// Validated with garde using a [`RegisterContext`] that carries the result of
/// the "already taken" database lookup.
#[derive(Debug, Default, Serialize, Deserialize, Clone, Eq, PartialEq, garde::Validate)]
#[garde(context(RegisterContext))]
2023-08-02 08:56:53 +02:00
struct AccountInfo {
2023-08-05 22:20:23 +02:00
/// Desired login; (de)serialized under the name `login`.
#[garde(length(min = 4, max = 30), custom(is_login_free))]
#[serde(rename = "login")]
input_login: String,
/// E-mail address; must be well-formed and not already registered.
#[garde(email, custom(is_email_free))]
2023-08-04 22:39:04 +02:00
email: String,
2023-08-05 22:20:23 +02:00
/// Plain-text password; checked for length and strength before hashing.
#[garde(length(min = 8, max = 50), custom(is_strong_password))]
2023-08-02 08:56:53 +02:00
password: String,
2023-08-04 22:39:04 +02:00
}
/// Full-page registration view, rendered from `./register/full.html`.
#[derive(askama_actix::Template)]
#[template(path = "./register/full.html")]
struct RegisterTemplate {
/// Form values to pre-fill.
form: AccountInfo,
/// Translation storage shared with the template.
t: Arc<TranslationStorage>,
/// Language the page is rendered in.
lang: Lang,
2023-08-05 22:20:23 +02:00
/// Errors to show to the user.
errors: oswilno_view::Errors,
2023-08-04 22:39:04 +02:00
}
#[get("/register")]
async fn register_view(t: Data<TranslationStorage>) -> RegisterTemplate {
RegisterTemplate {
form: AccountInfo::default(),
t: t.into_inner(),
lang: Lang::Pl,
2023-08-05 22:20:23 +02:00
errors: oswilno_view::Errors::default(),
2023-08-04 22:39:04 +02:00
}
}
2023-08-05 14:48:13 +02:00
#[get("/p/register")]
async fn register_partial_view(t: Data<TranslationStorage>) -> RegisterTemplate {
RegisterTemplate {
form: AccountInfo::default(),
t: t.into_inner(),
lang: Lang::Pl,
2023-08-05 22:20:23 +02:00
errors: oswilno_view::Errors::default(),
2023-08-05 14:48:13 +02:00
}
}
2023-08-04 22:39:04 +02:00
/// Registration form fragment, rendered from `./register/partial.html`; used to
/// re-render the form when `POST /register` fails.
#[derive(askama_actix::Template)]
#[template(path = "./register/partial.html")]
struct RegisterPartialTemplate {
/// Form values to pre-fill (the rejected submission on failure).
form: AccountInfo,
/// Translation storage shared with the template.
t: Arc<TranslationStorage>,
/// Language the fragment is rendered in.
lang: Lang,
2023-08-05 22:20:23 +02:00
/// Errors to show to the user.
errors: oswilno_view::Errors,
2023-08-02 08:56:53 +02:00
}
2023-08-03 16:16:46 +02:00
#[autometrics]
2023-08-02 08:56:53 +02:00
#[post("/register")]
async fn register(
db: Data<DatabaseConnection>,
2023-08-04 16:32:10 +02:00
payload: Form<AccountInfo>,
2023-08-04 22:39:04 +02:00
t: Data<oswilno_view::TranslationStorage>,
2023-08-06 22:14:03 +02:00
lang: Lang,
2023-08-02 08:56:53 +02:00
) -> Result<HttpResponse, Error> {
2023-08-04 22:39:04 +02:00
let t = t.into_inner();
2023-08-05 22:20:23 +02:00
let mut errors = oswilno_view::Errors::default();
2023-08-04 22:39:04 +02:00
Ok(
2023-08-05 22:20:23 +02:00
match register_internal(db.into_inner(), payload.into_inner(), &mut errors).await {
2023-08-04 22:39:04 +02:00
Ok(res) => res,
Err(p) => HttpResponse::BadRequest().body(
RegisterPartialTemplate {
form: p,
t,
2023-08-06 22:14:03 +02:00
lang,
2023-08-05 22:20:23 +02:00
errors,
2023-08-04 22:39:04 +02:00
}
.render()
.unwrap(),
),
},
)
}
2023-08-05 22:20:23 +02:00
/// Validation context handed to garde when validating [`AccountInfo`]; carries
/// the outcome of the database lookup performed before validation.
struct RegisterContext {
    // True when the submitted login already exists in the accounts table.
login_taken: bool,
    // True when the submitted e-mail already exists in the accounts table.
email_taken: bool,
}
fn is_email_free(_value: &str, context: &RegisterContext) -> garde::Result {
if context.email_taken {
return Err(garde::Error::new("is taken"));
}
Ok(())
}
fn is_login_free(_value: &str, context: &RegisterContext) -> garde::Result {
if context.login_taken {
return Err(garde::Error::new("is taken"));
}
Ok(())
}
static WEAK_PASS: &str = "is not strong enough";
fn is_strong_password(value: &str, _context: &RegisterContext) -> garde::Result {
if !(8..50).contains(&value.len()) {
return Err(garde::Error::new(WEAK_PASS));
}
let mut num = false;
let mut low = false;
let mut up = false;
let mut spec = false;
for c in value.chars() {
if num && low && up && spec {
return Ok(());
}
num = num || c.is_numeric();
low = low || c.is_lowercase();
up = up || c.is_uppercase();
spec = spec || !c.is_alphanumeric();
}
return Err(garde::Error::new(WEAK_PASS));
}
2023-08-04 22:39:04 +02:00
async fn register_internal(
db: Arc<DatabaseConnection>,
2023-08-05 22:20:23 +02:00
p: AccountInfo,
errors: &mut oswilno_view::Errors,
2023-08-04 22:39:04 +02:00
) -> Result<HttpResponse, AccountInfo> {
2023-08-02 08:56:53 +02:00
use oswilno_contract::accounts::*;
2023-08-02 12:45:29 +02:00
use sea_orm::*;
2023-08-02 08:56:53 +02:00
2023-08-05 22:20:23 +02:00
let query_result = db
.query_one(sea_orm::Statement::from_sql_and_values(
sea_orm::DbBackend::Postgres,
"select login = $1 as login_taken, email = $2 as email_taken from accounts",
[p.input_login.clone().into(), p.email.clone().into()],
))
.await
.map_err(|e| {
tracing::error!("{e}");
errors.push_global("Something went wrong");
p.clone()
})?;
let (login_taken, email_taken) = if let Some(query_result) = query_result {
let Ok((login_taken, email_taken)): Result<(bool,bool), _> = query_result.try_get_many("", &["login_taken".into(), "email_taken".into()]) else {
tracing::warn!("Failed to fetch fields from query result while checking if account info exists in db");
errors.push_global("Something went wrong");
return Err(p);
};
(login_taken, email_taken)
} else {
(false, false)
};
if let Err(e) = p.validate(&RegisterContext {
login_taken,
email_taken,
}) {
errors.consume_garde(e);
2023-08-06 22:14:03 +02:00
return Err(p);
2023-08-05 22:20:23 +02:00
}
tracing::warn!("{errors:#?}");
2023-08-02 12:45:29 +02:00
let pass = match hashing::encrypt(p.password.as_str()) {
Ok(p) => p,
Err(e) => {
tracing::warn!("{e}");
return Ok(HttpResponse::InternalServerError().body(""));
}
};
2023-08-06 22:14:03 +02:00
let model = (ActiveModel {
2023-08-02 12:45:29 +02:00
id: NotSet,
2023-08-05 22:20:23 +02:00
login: Set(p.input_login.to_string()),
2023-08-05 14:48:13 +02:00
email: Set(p.email.to_string()),
2023-08-02 12:45:29 +02:00
pass_hash: Set(pass),
..Default::default()
})
.save(&*db)
.await
2023-08-06 22:14:03 +02:00
.map_err(|e| {
tracing::warn!("{e}");
errors.push_global("Login or email already taken");
p
})?;
2023-08-02 12:45:29 +02:00
tracing::info!("{model:?}");
Ok(HttpResponse::SeeOther()
2023-08-02 16:37:03 +02:00
.append_header(("Location", "/login"))
2023-08-02 12:45:29 +02:00
.json(EmptyResponse {}))
2023-08-02 08:56:53 +02:00
}
2023-08-01 22:06:04 +02:00
/// Publishes token invalidation events to the authentication middleware.
#[derive(Clone)]
struct InvalidatedJWTStore {
2023-08-02 16:37:03 +02:00
// store: Arc<DashSet<JWT>>,
2023-08-09 15:42:29 +02:00
// Redis pool handed in at construction; no method visible in this file reads it.
redis: redis_async_pool::RedisPool,
2023-08-01 22:06:04 +02:00
// Sending half of the invalidation event channel, shared between clones of the store.
tx: Arc<Mutex<Sender<InvalidatedTokensEvent>>>,
}
impl InvalidatedJWTStore {
/// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s
2023-08-09 15:42:29 +02:00
fn new_with_stream(
redis: redis_async_pool::RedisPool,
) -> (
2023-08-01 22:06:04 +02:00
InvalidatedJWTStore,
impl Stream<Item = InvalidatedTokensEvent>,
) {
2023-08-02 16:37:03 +02:00
// let invalidated = Arc::new(DashSet::new());
2023-08-01 22:06:04 +02:00
let (tx, rx) = mpsc::channel(100);
let tx_to_hold = Arc::new(Mutex::new(tx));
(
InvalidatedJWTStore {
2023-08-02 16:37:03 +02:00
// store: invalidated,
2023-08-09 15:42:29 +02:00
redis,
2023-08-01 22:06:04 +02:00
tx: tx_to_hold,
},
rx,
)
}
async fn add_to_invalidated(&self, authenticated: Authenticated<Claims>) {
2023-08-02 16:37:03 +02:00
// self.store.insert(authenticated.jwt.clone());
2023-08-01 22:06:04 +02:00
let mut tx = self.tx.lock().await;
if let Err(_e) = tx
.send(InvalidatedTokensEvent::Add(authenticated.jwt))
.await
{
#[cfg(feature = "tracing")]
error!(error = ?_e, "Failed to send update on adding to invalidated")
}
}
}
/// Matching pair of JWT keys produced by `JwtSigningKeys::generate`.
struct JwtSigningKeys {
// Private key used to sign tokens.
encoding_key: EncodingKey,
// Public key used to verify tokens.
decoding_key: DecodingKey,
}
impl JwtSigningKeys {
    /// Generates a fresh Ed25519 key pair and wraps both halves as
    /// `jsonwebtoken` keys: the PKCS#8 document becomes the encoding key and
    /// the public key bytes become the decoding key.
    ///
    /// # Errors
    ///
    /// Fails when the key pair cannot be generated or the generated PKCS#8
    /// document is rejected on re-parsing.
    fn generate() -> Result<Self, Box<dyn std::error::Error>> {
        let rng = SystemRandom::new();
        let pkcs8 = Ed25519KeyPair::generate_pkcs8(&rng)?;
        let pkcs8_bytes = pkcs8.as_ref();
        let pair = Ed25519KeyPair::from_pkcs8(pkcs8_bytes)?;

        Ok(Self {
            encoding_key: EncodingKey::from_ed_der(pkcs8_bytes),
            decoding_key: DecodingKey::from_ed_der(pair.public_key().as_ref()),
        })
    }
}