use std::str::FromStr; use std::sync::Arc; use actix::{Addr, Message}; use chrono::prelude::*; use crate::database::{Database, TokenByJti}; use crate::model::{AccountId, Audience, Token, TokenString}; use crate::{database, token_async_handler, Role}; struct Jwt { /// cti (customer id): Customer uuid identifier used by payment service pub cti: uuid::Uuid, /// arl (account role): account role pub arl: Role, /// iss (issuer): Issuer of the JWT pub iss: String, /// sub (subject): Subject of the JWT (the user) pub sub: i32, /// aud (audience): Recipient for which the JWT is intended pub aud: Audience, /// exp (expiration time): Time after which the JWT expires pub exp: chrono::NaiveDateTime, /// nbt (not before time): Time before which the JWT must not be accepted /// for processing pub nbt: chrono::NaiveDateTime, /// iat (issued at time): Time at which the JWT was issued; can be used to /// determine age of the JWT, pub iat: chrono::NaiveDateTime, /// jti (JWT ID): Unique identifier; can be used to prevent the JWT from /// being replayed (allows a token to be used only once) pub jti: uuid::Uuid, } #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Unable to save new token")] Save, #[error("Unable to save new token. Can't connect to database")] SaveInternal, #[error("Unable to validate token")] Validate, #[error("Unable to validate token. Can't connect to database")] ValidateInternal, } pub type Result = std::result::Result; pub struct TokenManager { db: Addr, secret: Arc, } impl actix::Actor for TokenManager { type Context = actix::Context; } impl TokenManager { pub fn new(db: Addr) -> Self { let secret = Arc::new(std::env::var("JWT_SECRET").expect("JWT_SECRET is required")); Self { db, secret } } } #[derive(Message)] #[rtype(result = "Result<(Token, TokenString)>")] pub struct CreateToken { pub customer_id: uuid::Uuid, pub role: Role, pub subject: AccountId, pub audience: Option, } token_async_handler!(CreateToken, create_token, (Token, TokenString)); async fn create_token( msg: CreateToken, db: Addr, secret: Arc, ) -> Result<(Token, TokenString)> { let CreateToken { customer_id, role, subject, audience, } = msg; let audience = audience.unwrap_or_default(); let token: Token = match db .send(database::CreateToken { customer_id, role, subject, audience, }) .await { Ok(Ok(token)) => token, Ok(Err(db_err)) => { log::error!("{db_err}"); return Err(Error::Save); } Err(act_err) => { log::error!("{act_err:?}"); return Err(Error::SaveInternal); } }; let token_string = { use std::collections::BTreeMap; use hmac::{Hmac, Mac}; use jwt::SignWithKey; use sha2::Sha256; let key: Hmac = match Hmac::new_from_slice(secret.as_bytes()) { Ok(key) => key, Err(e) => { log::error!("{e:?}"); return Err(Error::SaveInternal); } }; let mut claims = BTreeMap::new(); // cti (customer id): Customer uuid identifier used by payment service claims.insert("cti", format!("{}", token.customer_id)); // arl (account role): account role claims.insert("arl", format!("{}", token.role.as_str())); // iss (issuer): Issuer of the JWT claims.insert("iss", format!("{}", token.issuer)); // sub (subject): Subject of the JWT (the user) claims.insert("sub", format!("{}", token.subject)); // aud (audience): Recipient for which the JWT is intended claims.insert("aud", format!("{}", token.audience.as_str())); // exp (expiration time): Time after which the JWT expires claims.insert("exp", format!("{}", token.expiration_time.format("%+"))); // nbt (not before time): Time before which the JWT must not be accepted // for processing claims.insert("nbt", format!("{}", token.not_before_time.format("%+"))); // iat (issued at time): Time at which the JWT was issued; can be used // to determine age of the JWT, claims.insert("iat", format!("{}", token.issued_at_time.format("%+"))); // jti (JWT ID): Unique identifier; can be used to prevent the JWT from // being replayed (allows a token to be used only once) claims.insert("jti", format!("{}", token.jwt_id)); TokenString::from(match claims.sign_with_key(&key) { Ok(s) => s, Err(e) => { log::error!("{e:?}"); return Err(Error::SaveInternal); } }) }; Ok((token, token_string)) } #[derive(Message)] #[rtype(result = "Result<(Token, bool)>")] pub struct Validate { pub token: TokenString, } token_async_handler!(Validate, validate, (Token, bool)); pub(crate) async fn validate( msg: Validate, db: Addr, secret: Arc, ) -> Result<(Token, bool)> { use std::collections::BTreeMap; use hmac::{Hmac, Mac}; use jwt::VerifyWithKey; use sha2::Sha256; log::info!("Validating token {:?}", msg.token); let key: Hmac = match Hmac::new_from_slice(secret.as_bytes()) { Ok(key) => key, Err(e) => { log::error!("{e:?}"); return Err(Error::ValidateInternal); } }; let claims: BTreeMap = match msg.token.verify_with_key(&key) { Ok(claims) => claims, _ => return Err(Error::Validate), }; let jti = match claims.get("jti") { Some(jti) => jti, _ => return Err(Error::Validate), }; let token: Token = match db .send(TokenByJti { jti: String::from(jti), }) .await { Ok(Ok(token)) => token, Ok(Err(e)) => { log::error!("{e}"); return Err(Error::Validate); } Err(e) => { log::error!("{e:?}"); return Err(Error::ValidateInternal); } }; match (claims.get("cti"), &token.customer_id) { (Some(cti), id) => { if !uuid::Uuid::from_str(cti) .map(|u| u == *id) .unwrap_or_default() { return Ok((token, false)); } } _ => return Ok((token, false)), } match (claims.get("arl"), &token.role) { (Some(arl), role) if arl == role.as_str() => {} _ => return Ok((token, false)), } match (claims.get("iss"), &token.issuer) { (Some(iss), issuer) if iss == issuer => {} _ => return Ok((token, false)), } match (claims.get("sub"), &token.subject) { (Some(sub), subject) => { if !sub .parse::() .map(|n| n == *subject) .unwrap_or_default() { return Ok((token, false)); } } _ => return Ok((token, false)), } match (claims.get("aud"), &token.audience) { (Some(aud), audience) if aud == audience.as_str() => {} _ => return Ok((token, false)), } match (claims.get("exp"), &token.expiration_time) { (Some(left), right) if validate_time(left, right) => {} _ => return Ok((token, false)), } match (claims.get("nbt"), &token.not_before_time) { (Some(left), right) if validate_time(left, right) => {} _ => return Ok((token, false)), } match (claims.get("iat"), &token.issued_at_time) { (Some(left), right) if validate_time(left, right) => {} _ => return Ok((token, false)), } Ok((token, true)) } fn validate_time(left: &str, right: &NaiveDateTime) -> bool { chrono::DateTime::parse_from_str(left, "%+") .map(|t| t.naive_utc() == *right) .unwrap_or_default() }