oswilno/crates/oswilno-session/src/lib.rs
2023-08-09 15:52:33 +02:00

353 lines
9.8 KiB
Rust

use std::ops::Add;
use std::sync::Arc;
use actix_jwt_authc::*;
use actix_web::web::{Data, Json, ServiceConfig, Query, Form};
use actix_web::{get, post, HttpResponse};
use autometrics::autometrics;
use futures::channel::{mpsc, mpsc::Sender};
use futures::stream::Stream;
use futures::SinkExt;
use jsonwebtoken::*;
use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use time::ext::*;
use time::OffsetDateTime;
use tokio::sync::Mutex;
use askama_actix::Template;
mod hashing;
/// Authenticated request session: [`Authenticated`] specialised to this crate's [`Claims`].
pub type UserSession = Authenticated<Claims>;
/// Token lifetime: the duration added to the current time to obtain the `exp`
/// claim of a freshly issued token.
#[derive(Clone, Copy)]
pub struct JWTTtl(time::Duration);
/// Value carried in the `aud` claim of [`Claims`].
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
enum Audience {
/// The only audience defined in this file; used for tokens issued by the login handler.
Web,
}
/// JWT claim set issued at login.
///
/// Every field is (de)serialised under the claim name given in its
/// `#[serde(rename = ...)]` attribute.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Claims {
/// `exp` claim: expiry time as a Unix timestamp in seconds.
#[serde(rename = "exp")]
expires_at: usize,
/// `iat` claim: time the token was issued, as a Unix timestamp in seconds.
#[serde(rename = "iat")]
issues_at: usize,
/// `sub` claim: the login handler fills it with `account-<account id>`.
#[serde(rename = "sub")]
subject: String,
/// `aud` claim.
#[serde(rename = "aud")]
audience: Audience,
/// `jti` claim: a UUID identifying this particular token.
#[serde(rename = "jti")]
jwt_id: uuid::Uuid,
}
/// Field-less response body, used by the `logout` and `register` handlers.
#[derive(Serialize, Deserialize)]
pub struct EmptyResponse {}
/// Body returned by a successful login.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
/// The signed, encoded JWT.
bearer_token: String,
/// The claims that were encoded into `bearer_token`.
claims: Claims,
}
/// Algorithm used both for the header of issued tokens and for the middleware's validator.
const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
/// Session state built by `SessionConfigurator::new`: the values that
/// `app_data` registers on the service config, plus the authentication
/// middleware factory handed out by `factory`.
#[derive(Clone)]
pub struct SessionConfigurator {
/// Lifetime applied to newly issued tokens.
jwt_ttl: Data<JWTTtl>,
/// Store through which the `logout` handler invalidates tokens.
invalidated_jwts_store: Data<InvalidatedJWTStore>,
/// Key used by the login handler to sign tokens.
encoding_key: Data<EncodingKey>,
/// Middleware factory configured with the matching decoding key.
factory: AuthenticateMiddlewareFactory<Claims>,
}
impl SessionConfigurator {
pub fn app_data(self, config: &mut ServiceConfig) {
config
.app_data(self.invalidated_jwts_store)
.app_data(self.encoding_key)
.app_data(self.jwt_ttl)
.service(login)
.service(login_view)
.service(logout)
.service(session_info);
}
pub fn factory(&self) -> AuthenticateMiddlewareFactory<Claims> {
self.factory.clone()
}
pub fn translations(&self, l10n: &mut oswilno_view::TranslationStorage) {
l10n
// English
.with_lang("en")
.add("Sign in", "Sign in")
.add("Sign up", "Sign up")
.add("Bad credentials", "Bad credentials")
.done()
// Polish
.with_lang("pl")
.add("Sign in", "Logowanie")
.add("Sign up", "Rejestracja")
.add("Bad credentials", "Złe dane uwierzytelniające")
.done();
}
pub fn new() -> Self {
let jwt_ttl = JWTTtl(31.days());
let jwt_signing_keys = JwtSigningKeys::generate().unwrap();
let validator = Validation::new(JWT_SIGNING_ALGO);
let auth_middleware_settings = AuthenticateMiddlewareSettings {
jwt_decoding_key: jwt_signing_keys.decoding_key,
jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]),
jwt_validator: validator,
jwt_session_key: None,
};
let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream();
let auth_middleware_factory =
AuthenticateMiddlewareFactory::<Claims>::new(stream, auth_middleware_settings.clone());
Self {
invalidated_jwts_store: Data::new(invalidated_jwts_store.clone()),
encoding_key: Data::new(jwt_signing_keys.encoding_key.clone()),
jwt_ttl: Data::new(jwt_ttl.clone()),
factory: auth_middleware_factory,
}
}
}
/// Askama template backed by `./sign-in/full.html`; returned by the
/// `GET /login` handler.
#[derive(Template)]
#[template(path = "./sign-in/full.html")]
struct SignInTemplate {
/// Form state made available to the template.
form: SignInPayload,
}
/// Askama template backed by `./sign-in/partial.html`; rendered by the
/// `POST /login` handler when sign-in fails.
#[derive(Template)]
#[template(path = "./sign-in/partial.html")]
struct SignInPartialTemplate {
/// Form state (including collected errors) made available to the template.
form: SignInPayload,
}
/// Sign-in form data plus the error lists shown when sign-in fails.
///
/// The `*errors` fields are marked `#[serde(skip, default)]`: they are never
/// read from the submitted form and start out empty.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq, Default)]
pub struct SignInPayload {
/// Form-level messages; the login handler pushes "Bad credentials" here.
#[serde(skip, default)]
errors: Vec<String>,
/// Submitted login name.
login: String,
// NOTE(review): presumably messages specific to the `login` field; nothing
// in this file writes to it.
#[serde(skip, default)]
login_errors: Vec<String>,
/// Submitted plain-text password.
password: String,
// NOTE(review): presumably messages specific to the `password` field;
// nothing in this file writes to it.
#[serde(skip, default)]
password_errors: Vec<String>,
}
/// `GET /login`: serves the sign-in template with an empty, error-free form.
#[get("/login")]
async fn login_view() -> SignInTemplate {
    let form = SignInPayload::default();
    SignInTemplate { form }
}
#[autometrics]
#[post("/login")]
async fn login(
jwt_encoding_key: Data<EncodingKey>,
jwt_ttl: Data<JWTTtl>,
db: Data<DatabaseConnection>,
payload: Form<SignInPayload>,
t: Data<oswilno_view::TranslationStorage>,
) -> Result<HttpResponse, Error> {
match login_inner(jwt_encoding_key, jwt_ttl, db, payload, t).await {
Ok(res) => Ok(HttpResponse::Ok().json(res)),
Err(form) => Ok(HttpResponse::Ok().body(SignInPartialTemplate { form }.render().unwrap())),
}
}
/// Verifies the submitted credentials and issues a signed JWT.
///
/// Returns the token and its claims on success. On any failure (unknown
/// login, database error, wrong password, token encoding error) the submitted
/// form is returned with a "Bad credentials" message appended to `errors`, so
/// the caller can re-render it.
///
/// NOTE(review): error messages are always translated to `"pl"`; the language
/// is not taken from the request.
async fn login_inner(
    jwt_encoding_key: Data<EncodingKey>,
    jwt_ttl: Data<JWTTtl>,
    db: Data<DatabaseConnection>,
    payload: Form<SignInPayload>,
    t: Data<oswilno_view::TranslationStorage>,
) -> Result<LoginResponse, SignInPayload> {
    use sea_orm::*;

    let db = db.into_inner();
    let mut payload = payload.into_inner();

    // Take the clock reading once so `iat` and `exp` are exactly one TTL apart.
    let now = OffsetDateTime::now_utc();
    let iat = now.unix_timestamp() as usize;
    let exp = now.add(jwt_ttl.0).unix_timestamp() as usize;

    let account = match oswilno_contract::accounts::Entity::find()
        .filter(oswilno_contract::accounts::Column::Login.eq(payload.login.as_str()))
        .one(&*db)
        .await
    {
        Ok(Some(a)) => a,
        Ok(None) => {
            payload.errors.push(t.to_lang("pl", "Bad credentials"));
            return Err(payload);
        }
        Err(e) => {
            // The user sees the same generic message; details go to the log only.
            tracing::warn!("Failed to find account: {e}");
            payload.errors.push(t.to_lang("pl", "Bad credentials"));
            return Err(payload);
        }
    };

    if hashing::verify(account.pass_hash.as_str(), payload.password.as_str()).is_err() {
        payload.errors.push(t.to_lang("pl", "Bad credentials"));
        return Err(payload);
    }

    let jwt_claims = Claims {
        issues_at: iat,
        subject: format!("account-{}", account.id),
        expires_at: exp,
        audience: Audience::Web,
        jwt_id: uuid::Uuid::new_v4(),
    };

    let jwt_token = match encode(
        &Header::new(JWT_SIGNING_ALGO),
        &jwt_claims,
        &jwt_encoding_key,
    ) {
        Ok(token) => token,
        Err(e) => {
            // This is a server-side fault, not a credential problem: log it so
            // it is not lost behind the generic user-facing message.
            tracing::warn!("Failed to encode JWT: {e}");
            payload.errors.push(t.to_lang("pl", "Bad credentials"));
            return Err(payload);
        }
    };

    Ok(LoginResponse {
        bearer_token: jwt_token,
        claims: jwt_claims,
    })
}
/// `GET /session`: sends the caller's authenticated session back as JSON.
#[autometrics]
#[get("/session")]
async fn session_info(authenticated: UserSession) -> Result<HttpResponse, Error> {
    let response = HttpResponse::Ok().json(authenticated);
    Ok(response)
}
/// `GET /logout`: passes the caller's session to the invalidated-token store
/// and replies with an empty JSON body.
#[autometrics]
#[get("/logout")]
async fn logout(
    invalidated_jwts: Data<InvalidatedJWTStore>,
    authenticated: Authenticated<Claims>,
) -> Result<HttpResponse, Error> {
    let store = invalidated_jwts.get_ref();
    store.add_to_invalidated(authenticated).await;

    let body = EmptyResponse {};
    Ok(HttpResponse::Ok().json(body))
}
/// Registration input extracted by the `register` handler.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
struct AccountInfo {
/// Requested login name.
login: String,
/// Plain-text password; it is hashed before being stored.
password: String,
}
/// `POST /register`: creates a new account.
///
/// Responses:
/// * `303 See Other` to `/login` once the account has been stored.
/// * `409 Conflict` when the login is already taken or the insert fails.
/// * `500 Internal Server Error` when the lookup or the password hashing fails.
///
/// NOTE(review): credentials are read from the query string (`Query`), where
/// they tend to end up in access logs; consider a form/JSON body.
/// NOTE(review): this handler is not mounted by `SessionConfigurator::app_data`
/// in the code visible here.
#[autometrics]
#[post("/register")]
async fn register(
    db: Data<DatabaseConnection>,
    payload: Query<AccountInfo>,
) -> Result<HttpResponse, Error> {
    use oswilno_contract::accounts::*;
    use sea_orm::ActiveValue::NotSet;
    use sea_orm::*;

    let db = db.into_inner();
    let p = payload.into_inner();

    // Reject duplicates first: only a *missing* account may proceed.
    match Entity::find()
        .filter(Column::Login.eq(&p.login))
        .one(&*db)
        .await
    {
        Ok(None) => {}
        Ok(Some(_)) => return Ok(HttpResponse::Conflict().body("")),
        Err(e) => {
            tracing::warn!("Failed to look up account: {e}");
            return Ok(HttpResponse::InternalServerError().body(""));
        }
    }

    // Hash only after the cheap rejection paths are out of the way.
    let pass = match hashing::encrypt(p.password.as_str()) {
        Ok(hash) => hash,
        Err(e) => {
            tracing::warn!("{e}");
            return Ok(HttpResponse::InternalServerError().body(""));
        }
    };

    let new_account = ActiveModel {
        id: NotSet,
        login: Set(p.login),
        pass_hash: Set(pass),
        ..Default::default()
    };
    let model = match new_account.save(&*db).await {
        Ok(model) => model,
        Err(e) => {
            // Covers e.g. a concurrent registration of the same login.
            tracing::warn!("{e}");
            return Ok(HttpResponse::Conflict().body(""));
        }
    };

    // NOTE(review): the Debug output presumably includes `pass_hash`; consider
    // logging only the account id/login.
    tracing::info!("{model:?}");

    Ok(HttpResponse::SeeOther()
        .append_header(("Location", "/login"))
        .json(EmptyResponse {}))
}
/// Handle used to publish token invalidations.
///
/// It keeps no set of invalidated tokens itself; it only holds the sending
/// half of the channel whose receiving half is given to the middleware factory.
#[derive(Clone)]
struct InvalidatedJWTStore {
/// Sender for invalidation events, behind an async mutex so clones of the
/// store share one sender.
tx: Arc<Mutex<Sender<InvalidatedTokensEvent>>>,
}
impl InvalidatedJWTStore {
/// Returns a [InvalidatedJWTStore] with a Stream of [InvalidatedTokensEvent]s
fn new_with_stream() -> (
InvalidatedJWTStore,
impl Stream<Item = InvalidatedTokensEvent>,
) {
// let invalidated = Arc::new(DashSet::new());
let (tx, rx) = mpsc::channel(100);
let tx_to_hold = Arc::new(Mutex::new(tx));
(
InvalidatedJWTStore {
// store: invalidated,
tx: tx_to_hold,
},
rx,
)
}
async fn add_to_invalidated(&self, authenticated: Authenticated<Claims>) {
// self.store.insert(authenticated.jwt.clone());
let mut tx = self.tx.lock().await;
if let Err(_e) = tx
.send(InvalidatedTokensEvent::Add(authenticated.jwt))
.await
{
#[cfg(feature = "tracing")]
error!(error = ?_e, "Failed to send update on adding to invalidated")
}
}
}
/// Matching pair of JWT keys derived from one Ed25519 key pair.
struct JwtSigningKeys {
/// Key used to sign tokens.
encoding_key: EncodingKey,
/// Key handed to the middleware settings as `jwt_decoding_key`.
decoding_key: DecodingKey,
}
impl JwtSigningKeys {
    /// Generates a new Ed25519 key pair and wraps it as JWT encoding and
    /// decoding keys.
    ///
    /// # Errors
    ///
    /// Returns the underlying error if the PKCS#8 document cannot be
    /// generated or parsed back into a key pair.
    fn generate() -> Result<Self, Box<dyn std::error::Error>> {
        let rng = SystemRandom::new();
        let pkcs8 = Ed25519KeyPair::generate_pkcs8(&rng)?;
        let der = pkcs8.as_ref();

        // Parse the document back to obtain the public half.
        let pair = Ed25519KeyPair::from_pkcs8(der)?;
        let public_bytes = pair.public_key().as_ref();

        Ok(Self {
            encoding_key: EncodingKey::from_ed_der(der),
            decoding_key: DecodingKey::from_ed_der(public_bytes),
        })
    }
}