use std::collections::BTreeMap; use std::str::FromStr; use std::sync::Arc; use actix::{Addr, Message}; use chrono::prelude::*; use hmac::digest::KeyInit; use hmac::Hmac; use sha2::Sha256; use crate::database::{Database, TokenByJti}; use crate::model::{AccountId, Audience, Token, TokenString}; use crate::{database, token_async_handler, Role}; struct Jwt { /// cti (customer id): Customer uuid identifier used by payment service pub cti: uuid::Uuid, /// arl (account role): account role pub arl: Role, /// iss (issuer): Issuer of the JWT pub iss: String, /// sub (subject): Subject of the JWT (the user) pub sub: i32, /// aud (audience): Recipient for which the JWT is intended pub aud: Audience, /// exp (expiration time): Time after which the JWT expires pub exp: chrono::NaiveDateTime, /// nbt (not before time): Time before which the JWT must not be accepted /// for processing pub nbt: chrono::NaiveDateTime, /// iat (issued at time): Time at which the JWT was issued; can be used to /// determine age of the JWT, pub iat: chrono::NaiveDateTime, /// jti (JWT ID): Unique identifier; can be used to prevent the JWT from /// being replayed (allows a token to be used only once) pub jti: uuid::Uuid, } #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Unable to save new token")] Save, #[error("Unable to save new token. Can't connect to database")] SaveInternal, #[error("Unable to validate token")] Validate, #[error("Unable to validate token. Can't connect to database")] ValidateInternal, } pub type Result = std::result::Result; pub struct TokenManager { db: Addr, secret: Arc, } impl actix::Actor for TokenManager { type Context = actix::Context; } impl TokenManager { pub fn new(db: Addr) -> Self { let secret = Arc::new(std::env::var("JWT_SECRET").expect("JWT_SECRET is required")); Self { db, secret } } } #[derive(Message)] #[rtype(result = "Result<(Token, TokenString)>")] pub struct CreateToken { pub customer_id: uuid::Uuid, pub role: Role, pub subject: AccountId, pub audience: Option, } token_async_handler!(CreateToken, create_token, (Token, TokenString)); pub(crate) async fn create_token( msg: CreateToken, db: Addr, secret: Arc, ) -> Result<(Token, TokenString)> { let CreateToken { customer_id, role, subject, audience, } = msg; let audience = audience.unwrap_or_default(); let token: Token = match db .send(database::CreateToken { customer_id, role, subject, audience, }) .await { Ok(Ok(token)) => token, Ok(Err(db_err)) => { log::error!("{db_err}"); return Err(Error::Save); } Err(act_err) => { log::error!("{act_err:?}"); return Err(Error::SaveInternal); } }; let token_string = { use jwt::SignWithKey; let key: Hmac = build_key(secret)?; let mut claims = BTreeMap::new(); // cti (customer id): Customer uuid identifier used by payment service claims.insert("cti", format!("{}", token.customer_id)); // arl (account role): account role claims.insert("arl", format!("{}", token.role.as_str())); // iss (issuer): Issuer of the JWT claims.insert("iss", format!("{}", token.issuer)); // sub (subject): Subject of the JWT (the user) claims.insert("sub", format!("{}", token.subject)); // aud (audience): Recipient for which the JWT is intended claims.insert("aud", format!("{}", token.audience.as_str())); // exp (expiration time): Time after which the JWT expires claims.insert("exp", format!("{}", token.expiration_time.format("%+"))); // nbt (not before time): Time before which the JWT must not be accepted // for processing claims.insert("nbt", format!("{}", token.not_before_time.format("%+"))); // iat (issued at time): Time at which the JWT was issued; can be used // to determine age of the JWT, claims.insert("iat", format!("{}", token.issued_at_time.format("%+"))); // jti (JWT ID): Unique identifier; can be used to prevent the JWT from // being replayed (allows a token to be used only once) claims.insert("jti", format!("{}", token.jwt_id)); TokenString::from(match claims.sign_with_key(&key) { Ok(s) => s, Err(e) => { log::error!("{e:?}"); return Err(Error::SaveInternal); } }) }; Ok((token, token_string)) } #[derive(Message)] #[rtype(result = "Result<(Token, bool)>")] pub struct Validate { pub token: TokenString, } token_async_handler!(Validate, validate, (Token, bool)); pub(crate) async fn validate( msg: Validate, db: Addr, secret: Arc, ) -> Result<(Token, bool)> { use jwt::VerifyWithKey; log::info!("Validating token {:?}", msg.token); let key: Hmac = build_key(secret)?; let claims: BTreeMap = match msg.token.verify_with_key(&key) { Ok(claims) => claims, _ => return Err(Error::Validate), }; let jti = match claims.get("jti") { Some(jti) => jti, _ => return Err(Error::Validate), }; let token: Token = match db .send(TokenByJti { jti: String::from(jti), }) .await { Ok(Ok(token)) => token, Ok(Err(e)) => { log::error!("{e}"); return Err(Error::Validate); } Err(e) => { log::error!("{e:?}"); return Err(Error::ValidateInternal); } }; if !validate_pair(&claims, "cti", token.customer_id, validate_uuid) { return Ok((token, false)); } // if !validate_pair(&claims, "arl", token.role, |left, right| right == left) { // return Ok((token, false)); // } match (claims.get("arl"), &token.role) { (Some(arl), role) if role == arl.as_str() => {} _ => return Ok((token, false)), } match (claims.get("iss"), &token.issuer) { (Some(iss), issuer) if iss == issuer => {} _ => return Ok((token, false)), } if !validate_pair(&claims, "sub", token.subject, validate_num) { return Ok((token, false)); } match (claims.get("aud"), &token.audience) { (Some(aud), audience) if aud == audience.as_str() => {} _ => return Ok((token, false)), } if !validate_pair(&claims, "exp", &token.expiration_time, validate_time) { return Ok((token, false)); } if !validate_pair(&claims, "nbt", &token.not_before_time, validate_time) { return Ok((token, false)); } if !validate_pair(&claims, "iat", &token.issued_at_time, validate_time) { return Ok((token, false)); } Ok((token, true)) } fn build_key(secret: Arc) -> Result> { match Hmac::new_from_slice(secret.as_bytes()) { Ok(key) => Ok(key), Err(e) => { log::error!("{e:?}"); Err(Error::ValidateInternal) } } } fn validate_pair(claims: &BTreeMap, key: &str, v: V, cmp: F) -> bool where F: FnOnce(&str, V) -> bool, V: PartialEq, { claims.get(key).map(|s| cmp(s, v)).unwrap_or_default() } fn validate_time(left: &str, right: &NaiveDateTime) -> bool { chrono::DateTime::parse_from_str(left, "%+") .map(|t| t.naive_utc() == *right) .unwrap_or_default() } fn validate_num(left: &str, right: i32) -> bool { left.parse::().map(|n| n == right).unwrap_or_default() } fn validate_uuid(left: &str, right: uuid::Uuid) -> bool { uuid::Uuid::from_str(left) .map(|u| u == right) .unwrap_or_default() }