From 1667270d73b4841e47ccbbe31186e71ea81f0a8c Mon Sep 17 00:00:00 2001 From: eraden Date: Tue, 19 Apr 2022 08:04:40 +0200 Subject: [PATCH] Validate token --- api/src/actors/token_manager.rs | 105 ++++++++++++++++---------------- api/src/model.rs | 8 ++- 2 files changed, 58 insertions(+), 55 deletions(-) diff --git a/api/src/actors/token_manager.rs b/api/src/actors/token_manager.rs index ee74073..60f9908 100644 --- a/api/src/actors/token_manager.rs +++ b/api/src/actors/token_manager.rs @@ -1,8 +1,12 @@ +use std::collections::BTreeMap; use std::str::FromStr; use std::sync::Arc; use actix::{Addr, Message}; use chrono::prelude::*; +use hmac::digest::KeyInit; +use hmac::Hmac; +use sha2::Sha256; use crate::database::{Database, TokenByJti}; use crate::model::{AccountId, Audience, Token, TokenString}; @@ -73,7 +77,7 @@ pub struct CreateToken { token_async_handler!(CreateToken, create_token, (Token, TokenString)); -async fn create_token( +pub(crate) async fn create_token( msg: CreateToken, db: Addr, secret: Arc, @@ -107,19 +111,9 @@ async fn create_token( }; let token_string = { - use std::collections::BTreeMap; - - use hmac::{Hmac, Mac}; use jwt::SignWithKey; - use sha2::Sha256; - let key: Hmac = match Hmac::new_from_slice(secret.as_bytes()) { - Ok(key) => key, - Err(e) => { - log::error!("{e:?}"); - return Err(Error::SaveInternal); - } - }; + let key: Hmac = build_key(secret)?; let mut claims = BTreeMap::new(); // cti (customer id): Customer uuid identifier used by payment service @@ -168,21 +162,11 @@ pub(crate) async fn validate( db: Addr, secret: Arc, ) -> Result<(Token, bool)> { - use std::collections::BTreeMap; - - use hmac::{Hmac, Mac}; use jwt::VerifyWithKey; - use sha2::Sha256; log::info!("Validating token {:?}", msg.token); - let key: Hmac = match Hmac::new_from_slice(secret.as_bytes()) { - Ok(key) => key, - Err(e) => { - log::error!("{e:?}"); - return Err(Error::ValidateInternal); - } - }; + let key: Hmac = build_key(secret)?; let claims: BTreeMap = match msg.token.verify_with_key(&key) { Ok(claims) => claims, _ => return Err(Error::Validate), @@ -209,59 +193,72 @@ pub(crate) async fn validate( } }; - match (claims.get("cti"), &token.customer_id) { - (Some(cti), id) => { - if !uuid::Uuid::from_str(cti) - .map(|u| u == *id) - .unwrap_or_default() - { - return Ok((token, false)); - } - } - _ => return Ok((token, false)), + if !validate_pair(&claims, "cti", token.customer_id, validate_uuid) { + return Ok((token, false)); } + // if !validate_pair(&claims, "arl", token.role, |left, right| right == left) { + // return Ok((token, false)); + // } match (claims.get("arl"), &token.role) { - (Some(arl), role) if arl == role.as_str() => {} + (Some(arl), role) if role == arl.as_str() => {} _ => return Ok((token, false)), } match (claims.get("iss"), &token.issuer) { (Some(iss), issuer) if iss == issuer => {} _ => return Ok((token, false)), } - match (claims.get("sub"), &token.subject) { - (Some(sub), subject) => { - if !sub - .parse::() - .map(|n| n == *subject) - .unwrap_or_default() - { - return Ok((token, false)); - } - } - _ => return Ok((token, false)), + if !validate_pair(&claims, "sub", token.subject, validate_num) { + return Ok((token, false)); } + match (claims.get("aud"), &token.audience) { (Some(aud), audience) if aud == audience.as_str() => {} _ => return Ok((token, false)), } - match (claims.get("exp"), &token.expiration_time) { - (Some(left), right) if validate_time(left, right) => {} - _ => return Ok((token, false)), + + if !validate_pair(&claims, "exp", &token.expiration_time, validate_time) { + return Ok((token, false)); } - match (claims.get("nbt"), &token.not_before_time) { - (Some(left), right) if validate_time(left, right) => {} - _ => return Ok((token, false)), + if !validate_pair(&claims, "nbt", &token.not_before_time, validate_time) { + return Ok((token, false)); } - match (claims.get("iat"), &token.issued_at_time) { - (Some(left), right) if validate_time(left, right) => {} - _ => return Ok((token, false)), + if !validate_pair(&claims, "iat", &token.issued_at_time, validate_time) { + return Ok((token, false)); } Ok((token, true)) } +fn build_key(secret: Arc) -> Result> { + match Hmac::new_from_slice(secret.as_bytes()) { + Ok(key) => Ok(key), + Err(e) => { + log::error!("{e:?}"); + Err(Error::ValidateInternal) + } + } +} + +fn validate_pair(claims: &BTreeMap, key: &str, v: V, cmp: F) -> bool +where + F: FnOnce(&str, V) -> bool, + V: PartialEq, +{ + claims.get(key).map(|s| cmp(s, v)).unwrap_or_default() +} + fn validate_time(left: &str, right: &NaiveDateTime) -> bool { chrono::DateTime::parse_from_str(left, "%+") .map(|t| t.naive_utc() == *right) .unwrap_or_default() } + +fn validate_num(left: &str, right: i32) -> bool { + left.parse::().map(|n| n == right).unwrap_or_default() +} + +fn validate_uuid(left: &str, right: uuid::Uuid) -> bool { + uuid::Uuid::from_str(left) + .map(|u| u == right) + .unwrap_or_default() +} diff --git a/api/src/model.rs b/api/src/model.rs index 308f1f3..37b2054 100644 --- a/api/src/model.rs +++ b/api/src/model.rs @@ -31,7 +31,7 @@ pub enum OrderStatus { Refunded, } -#[derive(sqlx::Type, Copy, Clone, Debug, Display, Deserialize, Serialize)] +#[derive(sqlx::Type, Copy, Clone, Debug, Display, Deserialize, Serialize, PartialEq)] #[sqlx(rename_all = "snake_case")] pub enum Role { #[display(fmt = "Adminitrator")] @@ -40,6 +40,12 @@ pub enum Role { User, } +impl PartialEq for Role { + fn eq(&self, other: &str) -> bool { + self.as_str() == other + } +} + impl Role { pub fn as_str(&self) -> &str { match self {