eraden 2023-12-27 13:36:07 +01:00
commit fa31b4ab74
21 changed files with 2643 additions and 0 deletions

1
.gitignore vendored Normal file

@@ -0,0 +1 @@
/target

1927
Cargo.lock generated Normal file

File diff suppressed because it is too large

10
Cargo.toml Normal file

@@ -0,0 +1,10 @@
[workspace]
resolver = "2"
members = [
"crates/database",
"crates/events",
"crates/identity-agent",
"crates/router",
"crates/sessions-agent",
"crates/shared-config",
]

3
README.md Normal file

@@ -0,0 +1,3 @@
## Dependencies
* rumqttd

3
config/private.pem Normal file

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIHIQSS3xUjp+QFz6YqTA9G0s46QhBImxBNhbGTbV5beh
-----END PRIVATE KEY-----

3
config/public.pem Normal file

@@ -0,0 +1,3 @@
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEArWkJ/nxD4GpA3t0T/nMqXuWc+BZDHh1jWADs8woa2GI=
-----END PUBLIC KEY-----

126
config/rumqttd.toml Normal file

@@ -0,0 +1,126 @@
id = 0
# A commitlog read will pull a full segment. Make sure that a segment isn't
# too big, as the async TCP write readiness of one connection might affect the
# tail latencies of other connections. Not a problem with preempting runtimes.
[router]
dir = "./data/rumqttd"
id = 0
max_connections = 10010
max_outgoing_packet_count = 200
max_segment_size = 104857600
max_segment_count = 10
# shared_subscriptions_strategy = "random" # "sticky" | "roundrobin" ( default ) | "random"
# Any filter that matches a configured filter below will use a custom segment size.
# [router.custom_segment.'/office/+/devices/status']
# max_segment_size = 102400
# max_segment_count = 2
# [router.custom_segment.'/home/+/devices/status']
# max_segment_size = 51200
# max_segment_count = 2
# [bridge]
# name = "bridge-1"
# addr = "localhost:1883"
# qos = 0
# sub_path = "#"
# reconnection_delay = 5
# ping_delay = 5
# timeout_delay = 5
# [bridge.connections]
# connection_timeout_ms = 60000
# max_payload_size = 20480
# max_inflight_count = 500
# dynamic_filters = true
# [bridge.transport.tls]
# ca = "ca.cert.pem"
# client_auth = { certs = "test-1.cert.pem", key = "test-1.key.pem" }
[servers]
# Configuration of server and connections that it accepts
[servers.v4]
name = "v4-1"
listen = "0.0.0.0:1883"
next_connection_delay_ms = 1
[servers.v4.connections]
connection_timeout_ms = 60000
max_payload_size = 20480
max_inflight_count = 100
dynamic_filters = true
max_client_id_len = 256
throttle_delay_ms = 0
max_inflight_size = 1024
# auth = { user1 = "p@ssw0rd", user2 = "password" }
# [v4.1.connections.auth]
# user1 = "p@ssw0rd"
# user2 = "password"
# [v4.2]
# name = "v4-2"
# listen = "0.0.0.0:8883"
# next_connection_delay_ms = 10
# # tls config for rustls
# [v4.2.tls]
# capath = "/etc/tls/ca.cert.pem"
# certpath = "/etc/tls/server.cert.pem"
# keypath = "/etc/tls/server.key.pem"
# # settings for all the connections on this server
# [v4.2.connections]
# connection_timeout_ms = 60000
# throttle_delay_ms = 0
# max_payload_size = 20480
# max_inflight_count = 100
# max_inflight_size = 1024
[servers.v5]
name = "v5-1"
listen = "0.0.0.0:1884"
next_connection_delay_ms = 1
[servers.v5.connections]
connection_timeout_ms = 60000
max_payload_size = 20480
max_inflight_count = 100
max_client_id_len = 256
throttle_delay_ms = 0
max_inflight_size = 1024
# [prometheus]
# listen = "127.0.0.1:9042"
# interval = 1
[servers.ws]
name = "ws-1"
listen = "0.0.0.0:8083"
next_connection_delay_ms = 1
[servers.ws.connections]
connection_timeout_ms = 60000
max_client_id_len = 256
throttle_delay_ms = 0
max_payload_size = 20480
max_inflight_count = 500
max_inflight_size = 1024
# [ws.2]
# name = "ws-2"
# listen = "0.0.0.0:8081"
# next_connection_delay_ms = 1
# [ws.2.tls]
# capath = "/etc/tls/ca.cert.pem"
# certpath = "/etc/tls/server.cert.pem"
# keypath = "/etc/tls/server.key.pem"
# [ws.2.connections]
# connection_timeout_ms = 60000
# max_client_id_len = 256
# throttle_delay_ms = 0
# max_payload_size = 20480
# max_inflight_count = 500
# max_inflight_size = 1024
[console]
listen = "0.0.0.0:3030"
# [metrics]
# [metrics.alerts]
# push_interval = 1
# [metrics.meters]
# push_interval = 1
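Nothing in the workspace exercises these listeners yet. As a quick smoke test, a bare rumqttc client (the same crate the `events` crate builds on) can be pointed at the `servers.v4` listener configured above. This is only a sketch: the client id and topic filter are made up, and `127.0.0.1` assumes the broker runs locally.

```rust
use rumqttc::{AsyncClient, MqttOptions, QoS};
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Host/port match the `servers.v4` listener above; client id and topic are arbitrary.
    let mut opts = MqttOptions::new("smoke-test", "127.0.0.1", 1883);
    opts.set_keep_alive(Duration::from_secs(5));
    let (client, mut event_loop) = AsyncClient::new(opts, 10);
    client.subscribe("app/#", QoS::AtMostOnce).await.unwrap();
    // Print every broker event as it arrives; Ctrl-C to stop.
    loop {
        println!("{:?}", event_loop.poll().await.unwrap());
    }
}
```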

12
crates/database/Cargo.toml Normal file

@@ -0,0 +1,12 @@
[package]
name = "database"
version = "0.1.0"
edition = "2021"
[dependencies]
sea-orm = { version = "0", features = ["runtime-tokio-rustls", "sqlx-sqlite", "sqlx-postgres"] }
sea-orm-migration = { version = "0", features = ["runtime-tokio-rustls", "sqlx-sqlite", "sqlx-postgres"] }
migration = { path = "../migration" }
entities = { path = "../entities" }
chrono = { version = "0.4.31", default-features = false, features = ["serde", "clock", "libc"] }
uuid = { version = "1.6.1", features = ["v4"] }

41
crates/database/src/lib.rs Normal file

@@ -0,0 +1,41 @@
pub use sea_orm;
pub use sea_orm_migration;
pub use chrono;
pub use sea_orm::prelude::*;
pub use uuid;
fn sqlite_file_path() -> String {
let file = std::env::current_dir()
.expect("No working directory")
.join("database.db?mode=rwc")
.display()
.to_string();
format!("sqlite:{file}")
}
pub trait DatabaseUrl {
fn database_url(&self) -> String {
self.provided_url()
.cloned()
.or_else(|| std::env::var("DATABASE_URL").ok())
.unwrap_or_else(sqlite_file_path)
}
fn provided_url(&self) -> Option<&String>;
}
pub async fn run_migration<Migrator: sea_orm_migration::MigratorTrait>(opts: &impl DatabaseUrl) {
let connection = sea_orm::Database::connect(opts.database_url())
.await
.expect("Failed to connect to database");
Migrator::up(&connection, None)
.await
.expect("Failed to run migration");
}
pub async fn db_connect(opts: &impl DatabaseUrl) -> DatabaseConnection {
sea_orm::Database::connect(opts.database_url())
.await
.expect("Failed to connect to database")
}
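For reference, a minimal sketch of how one of the agent crates might wire its options into this helper. The `Opts` struct and its field are assumptions; only `DatabaseUrl`, `provided_url`, and `db_connect` come from the crate above.

```rust
// Hypothetical caller-side options; only the `database` items are real.
struct Opts {
    // e.g. from a --database-url flag; None falls back to DATABASE_URL or the sqlite file.
    database_url: Option<String>,
}

impl database::DatabaseUrl for Opts {
    fn provided_url(&self) -> Option<&String> {
        self.database_url.as_ref()
    }
}

#[tokio::main]
async fn main() {
    let opts = Opts { database_url: None };
    let _conn = database::db_connect(&opts).await;
}
```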

19
crates/events/Cargo.toml Normal file

@@ -0,0 +1,19 @@
[package]
name = "events"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bincode = "1.3.3"
bytes = "1.5.0"
futures = { version = "0.3.29", default-features = false, features = ["async-await", "futures-executor", "thread-pool", "compat", "io-compat", "std"] }
futures-util = { version = "0.3.29", default-features = false, features = ["async-await", "futures-channel", "futures-io", "futures-macro"] }
rumqttc = "0.23.0"
serde = "1.0.193"
thiserror = "1.0.51"
tokio = { version = "1.35.1", default-features = false, features = ["full"] }
toml = { version = "0.8.8", default-features = false, features = ["parse", "indexmap", "preserve_order"] }
tracing = "0.1.40"
uuid = { version = "1.6.1", default-features = false, features = ["v4", "serde"] }

129
crates/events/src/lib.rs Normal file

@@ -0,0 +1,129 @@
use bytes::Bytes;
pub use rumqttc::QoS;
use rumqttc::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, Publish};
use std::time::Duration;
pub type RecordId = u32;
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct UserId(RecordId);
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub enum UserEvent {
SignIn(UserId),
SignOut(UserId),
}
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub enum TokenEvent {
Invalidated
}
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub enum AppEvent {
User(UserEvent),
}
pub fn create_conn<S: Into<String>>(
device_id: S,
bind: Option<&str>,
port: Option<u16>,
) -> (AsyncClient, EventLoop) {
let mut mqttoptions =
MqttOptions::new(device_id, bind.unwrap_or("0.0.0.0"), port.unwrap_or(1883));
mqttoptions
.set_keep_alive(Duration::from_secs(5))
.set_manual_acks(true)
.set_clean_session(false);
AsyncClient::new(mqttoptions, 10)
}
pub struct Bus {
client: AsyncClient,
}
impl Bus {
pub fn new<Handler, Fut, S: Into<String>>(
event_handler: Handler,
device_id: S,
bind: Option<&str>,
port: Option<u16>,
) -> Self
where
Fut: std::future::Future + Send + 'static,
Handler: Fn(String, AppEvent) -> Fut + Send + 'static,
{
let (client, event_loop) = create_conn(device_id, bind, port);
tokio::spawn(run(event_handler, event_loop, client.clone()));
Self { client }
}
pub async fn publish(&self, topic: &str, qos: QoS, retain: bool, app_event: AppEvent) -> Result<(), ()> {
let bytes = bincode::serialize(&app_event).map_err(|_| ())?;
let payload = Bytes::from_iter(bytes.into_iter());
self.client.publish_bytes(topic, qos, retain, payload).await.map_err(|_| ())?;
Ok(())
}
}
pub async fn run<Handler, Fut>(
event_handler: Handler,
mut event_loop: EventLoop,
client: AsyncClient,
) where
Fut: std::future::Future + Send + 'static,
Handler: Fn(String, AppEvent) -> Fut + Send + 'static,
{
loop {
// previously published messages should be republished after reconnection.
let event = event_loop.poll().await;
let event = match event {
Ok(notif) => {
println!("Event = {notif:?}");
notif
}
Err(error) => {
println!("Error = {error:?}");
break;
}
};
let publish = match event {
Event::Incoming(Incoming::Publish(publish)) => {
publish
}
Event::Incoming(_) => {
continue;
}
Event::Outgoing(_) => {
continue;
}
};
let Publish {
pkid: _,
dup: _,
qos: _,
retain: _,
topic,
payload,
} = &publish;
match bincode::deserialize::<AppEvent>(&*payload) {
Ok(msg) => {
event_handler(topic.clone(), msg).await;
}
Err(e) => {
tracing::warn!("Invalid message: {e}");
}
};
// This time we ack incoming publishes ourselves (manual acks are enabled).
// It's important not to block the event loop here, as that can cause a deadlock.
let c = client.clone();
tokio::spawn(async move {
c.ack(&publish).await.unwrap();
});
}
}
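A hedged usage sketch for the bus: the client id, broker address, and topic are placeholders, and it assumes the code sits inside this crate (or gets a public constructor) so the private `UserId` field is reachable. Note there is no subscribe helper yet, so the handler only fires for filters the underlying client has been subscribed to elsewhere.

```rust
// Sketch only: topic and ids are made up; `UserId(7)` assumes crate-internal access.
use crate::{AppEvent, Bus, QoS, UserEvent, UserId};

async fn demo() {
    let bus = Bus::new(
        |topic: String, event: AppEvent| async move {
            println!("handled {event:?} on {topic}");
        },
        "sessions-agent-1", // MQTT client id
        Some("127.0.0.1"),  // broker host (servers.v4 in config/rumqttd.toml)
        Some(1883),         // broker port
    );
    bus.publish(
        "app/users",
        QoS::AtLeastOnce,
        false,
        AppEvent::User(UserEvent::SignOut(UserId(7))),
    )
    .await
    .expect("failed to publish sign-out event");
}
```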

8
crates/identity-agent/Cargo.toml Normal file

@@ -0,0 +1,8 @@
[package]
name = "identity-agent"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

3
crates/identity-agent/src/main.rs Normal file

@@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

18
crates/router/Cargo.toml Normal file

@@ -0,0 +1,18 @@
[package]
name = "router"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
chrono = { version = "0.4.31", default-features = false, features = ["serde", "clock", "pure-rust-locales"] }
futures = { version = "0.3.29", default-features = false, features = ["async-await", "futures-executor", "thread-pool", "compat", "io-compat", "std"] }
futures-util = { version = "0.3.29", default-features = false, features = ["async-await", "futures-channel", "futures-io", "futures-macro"] }
serde = { version = "1.0.193", features = ["serde_derive"] }
serde_toml = { version = "0.0.1", default-features = false }
thiserror = "1.0.51"
tokio = { version = "1.35.1", default-features = false, features = ["full"] }
toml = { version = "0.8.8", default-features = false, features = ["parse", "indexmap", "preserve_order"] }
tracing = "0.1.40"
uuid = { version = "1.6.1", default-features = false, features = ["v4", "serde"] }

3
crates/router/src/main.rs Normal file

@@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

19
crates/sessions-agent/Cargo.toml Normal file

@@ -0,0 +1,19 @@
[package]
name = "sessions-agent"
version = "0.1.0"
edition = "2021"
[dependencies]
chrono = { version = "0.4.31", default-features = false, features = ["serde", "clock", "pure-rust-locales"] }
futures = { version = "0.3.29", default-features = false, features = ["async-await", "futures-executor", "thread-pool", "compat", "io-compat", "std"] }
futures-util = { version = "0.3.29", default-features = false, features = ["async-await", "futures-channel", "futures-io", "futures-macro"] }
serde = { version = "1.0.193", features = ["serde_derive"] }
serde_toml = { version = "0.0.1", default-features = false }
thiserror = "1.0.51"
tokio = { version = "1.35.1", default-features = false, features = ["full"] }
toml = { version = "0.8.8", default-features = false, features = ["parse", "indexmap", "preserve_order"] }
tracing = "0.1.40"
uuid = { version = "1.6.1", default-features = false, features = ["v4", "serde"] }
events = { path = "../events" }
actix = "0.13.1"
actix-web = "4.4.1"

3
crates/sessions-agent/src/main.rs Normal file

@@ -0,0 +1,3 @@
#[actix::main]
async fn main() {
}


@@ -0,0 +1,294 @@
use std::collections::HashSet;
use std::hash::{DefaultHasher, Hasher};
use std::num::TryFromIntError;
use std::path::PathBuf;
use std::sync::Arc;
use chrono::NaiveDateTime;
use database::chrono::Utc;
use database::sea_orm::ActiveValue::*;
use database::{chrono, sessions, uuid};
use jsonwebtoken::*;
use serde::{Deserialize, Serialize};
use tracing::*;
pub static KEYS_PATH: &str = "./config";
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
#[serde(rename = "sum")]
pub hash_sum: i64,
#[serde(rename = "sub")]
pub subject: i32,
#[serde(rename = "aud")]
pub audience: String,
pub role: String,
#[serde(rename = "iss")]
pub issuer: String,
#[serde(rename = "nbt")]
pub not_before_time: NaiveDateTime,
#[serde(rename = "exp")]
pub exp: NaiveDateTime,
#[serde(rename = "iat")]
pub issued_at: NaiveDateTime,
#[serde(rename = "jti")]
pub jwt_unique_identifier: uuid::Uuid,
}
impl Claims {
pub async fn new(
subject: i32,
audience: String,
role: String,
issuer: Option<String>,
not_before_time: chrono::NaiveDateTime,
issued_at: chrono::NaiveDateTime,
expiration_duration: Option<chrono::Duration>,
) -> Result<Self, std::num::TryFromIntError> {
let issuer = issuer.unwrap_or_else(|| "ergokeyboard".to_owned());
let mut claims = Claims {
subject,
audience,
role,
issuer,
hash_sum: 0,
not_before_time,
exp: issued_at + expiration_duration.unwrap_or(chrono::Duration::days(365)),
issued_at,
jwt_unique_identifier: uuid::Uuid::new_v4(),
};
let hash_sum = generate_hash_sum(&claims).await as i64;
claims.hash_sum = hash_sum;
Ok(claims)
}
}
macro_rules! cmp_both {
($l: expr, $r: expr, $($field: ident),+) => {
$(
$l.$field == $r.$field &&
)+ true
}
}
impl PartialEq<sessions::Model> for Claims {
fn eq(&self, s: &sessions::Model) -> bool {
cmp_both!(
self,
s,
hash_sum,
subject,
audience,
role,
issuer,
not_before_time,
issued_at,
jwt_unique_identifier
) && self.exp == s.expiration_time
}
}
#[derive(Debug, thiserror::Error)]
pub enum KeysError {
#[error("Decode key failed on file system error: {0}")]
DecodeKeyIo(std::io::Error),
#[error("Decode key failed on file system error: {0}")]
EncodeKeyIo(std::io::Error),
#[error("Decode key failed to parse ed25519 key: {0}")]
DecodeKeyParsing(jsonwebtoken::errors::Error),
#[error("Encode key failed to parse ed25519 key: {0}")]
EncodeKeyParsing(jsonwebtoken::errors::Error),
}
pub struct JwtKeysInner {
decode: DecodingKey,
encode: EncodingKey,
}
const DECODE_KEY_NAME: &str = "public.pem";
const ENCODE_KEY_NAME: &str = "private.pem";
#[derive(Clone, derive_more::Deref)]
pub struct JwtKeys(Arc<JwtKeysInner>);
impl JwtKeys {
pub fn load(config_path: PathBuf) -> Result<Self, KeysError> {
Ok(Self(Arc::new(JwtKeysInner {
decode: DecodingKey::from_ed_pem(
&std::fs::read(config_path.join(DECODE_KEY_NAME))
.map_err(KeysError::DecodeKeyIo)?,
)
.map_err(KeysError::DecodeKeyParsing)?,
encode: EncodingKey::from_ed_pem(
&std::fs::read(config_path.join(ENCODE_KEY_NAME))
.map_err(KeysError::EncodeKeyIo)?,
)
.map_err(KeysError::EncodeKeyParsing)?,
})))
}
}
pub async fn generate_token(
subject: i32,
audience: String,
role: String,
issuer: Option<String>,
not_before_time: chrono::NaiveDateTime,
issued_at: chrono::NaiveDateTime,
expiration_duration: Option<chrono::Duration>,
) -> Result<database::sessions::ActiveModel, TryFromIntError> {
let claims = Claims::new(
subject,
audience,
role,
issuer.clone(),
not_before_time,
issued_at,
expiration_duration,
)
.await?;
Ok(database::sessions::ActiveModel {
hash_sum: Set(claims.hash_sum),
subject: Set(claims.subject),
audience: Set(claims.audience),
role: Set(claims.role),
issuer: Set(claims.issuer),
not_before_time: Set(claims.not_before_time),
expiration_time: Set(claims.exp),
issued_at: Set(claims.issued_at),
jwt_unique_identifier: Set(claims.jwt_unique_identifier),
..Default::default()
})
}
pub async fn generate_hash_sum(claims: &Claims) -> i64 {
let Claims {
subject,
hash_sum: _,
audience,
role,
issuer,
not_before_time,
exp: expiration_time,
issued_at,
jwt_unique_identifier,
} = claims;
let mut hasher = DefaultHasher::default();
hasher.write_i32(*subject);
hasher.write(audience.as_bytes());
hasher.write(role.as_bytes());
hasher.write(issuer.as_bytes());
hasher.write_i64(not_before_time.timestamp_nanos_opt().expect("invalid NBT"));
hasher.write_i64(expiration_time.timestamp_nanos_opt().expect("invalid EXP"));
hasher.write_i64(issued_at.timestamp_nanos_opt().expect("invalid IAT"));
hasher.write(jwt_unique_identifier.as_bytes());
hasher.finish() as i64
}
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
#[error("Given JWT text is not valid")]
InvalidString,
#[error("Can't load accounts from database")]
FetchAccounts,
#[error("Can't load sessions from database")]
FetchSessions,
#[error("Account for given ID does not exists")]
NoAccount,
#[error("Given token does not exists")]
UnknownToken,
#[error("Given token expired")]
Expired,
}
pub async fn validate(
db_client: database::DatabaseConnection,
keys: JwtKeys,
token: &str,
) -> Result<database::accounts::Model, ValidationError> {
use database::*;
let mut validation = jsonwebtoken::Validation::new(Algorithm::EdDSA);
validation.validate_exp = false;
validation.required_spec_claims = HashSet::new();
validation.set_audience(&["Web"]);
validation.set_issuer(&["ergokeyboard"]);
tracing::info!("decoding token: {token:?}");
let token = match jsonwebtoken::decode::<Claims>(
&token,
&keys.decode,
&validation,
) {
Err(e) => {
warn!("Failed to decode token: {e}");
return Err(ValidationError::InvalidString);
}
Ok(token) => token.claims,
};
tracing::trace!("claims are: {token:?}");
let hash = generate_hash_sum(&token).await as i64;
let Ok(mut rows) = Sessions::find()
.filter(entities::sessions::Column::HashSum.eq(hash))
.all(&db_client)
.await
else {
return Err(ValidationError::FetchSessions);
};
let Some(found_idx) = rows.iter().position(|row| token == *row) else {
return Err(ValidationError::UnknownToken);
};
let found = rows.remove(found_idx);
if found.expiration_time < Utc::now().naive_utc() {
return Err(ValidationError::Expired);
}
match Accounts::find()
.filter(database::accounts::Column::Id.eq(found.subject))
.one(&db_client)
.await
{
Err(e) => {
error!("Failed to load account for {found:?}: {e}");
Err(ValidationError::FetchAccounts)
}
Ok(None) => Err(ValidationError::NoAccount),
Ok(Some(account)) => Ok(account),
}
}
pub async fn create_jwt_string(
keys: JwtKeys,
claims: &Claims,
) -> Result<String, jsonwebtoken::errors::Error> {
jsonwebtoken::encode(
&jsonwebtoken::Header::new(Algorithm::EdDSA),
claims,
&keys.encode,
)
}
#[cfg(test)]
mod tests {
use database::chrono::Utc;
use super::*;
use std::path::Path;
#[tokio::test]
async fn create_string() {
let keys = JwtKeys::load(Path::new("./config").to_owned()).unwrap();
let claims = Claims::new(
234,
"jaosidf".into(),
"User".into(),
None,
Utc::now().naive_utc(),
Utc::now().naive_utc(),
None,
)
.await
.unwrap();
let _text = create_jwt_string(keys, &claims).await.unwrap();
}
}
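To tie the pieces above together, a hedged sketch of the intended issue-then-validate round trip. The subject id, audience, and role are placeholders; the audience "Web" and the default issuer "ergokeyboard" match what `validate` checks.

```rust
// Sketch of the intended flow; only the functions and types from this file are real.
async fn issue_and_validate(
    db: database::DatabaseConnection,
    keys: JwtKeys,
) -> Result<(), Box<dyn std::error::Error>> {
    use database::chrono::Utc;

    let now = Utc::now().naive_utc();
    // 1. Build claims and sign them into the string handed to the client.
    let claims = Claims::new(1, "Web".into(), "User".into(), None, now, now, None).await?;
    let jwt = create_jwt_string(keys.clone(), &claims).await?;
    // 2. A session row mirroring these exact claims (what `generate_token` produces)
    //    must be inserted first; otherwise `validate` stops at UnknownToken, since it
    //    looks the presented claims up by hash_sum.
    // 3. A later request presents the JWT and gets its account back.
    let _account = validate(db, keys, &jwt).await;
    Ok(())
}
```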

18
crates/shared-config/Cargo.toml Normal file

@@ -0,0 +1,18 @@
[package]
name = "shared-config"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
chrono = { version = "0.4.31", default-features = false, features = ["serde", "clock", "pure-rust-locales"] }
futures = { version = "0.3.29", default-features = false, features = ["async-await", "futures-executor", "thread-pool", "compat", "io-compat", "std"] }
futures-util = { version = "0.3.29", default-features = false, features = ["async-await", "futures-channel", "futures-io", "futures-macro"] }
serde = { version = "1.0.193", features = ["serde_derive"] }
serde_toml = { version = "0.0.1", default-features = false }
thiserror = "1.0.51"
tokio = { version = "1.35.1", default-features = false, features = ["full"] }
toml = { version = "0.8.8", default-features = false, features = ["parse", "indexmap", "preserve_order"] }
tracing = "0.1.40"
uuid = { version = "1.6.1", default-features = false, features = ["v4", "serde"] }

3
crates/shared-config/src/main.rs Normal file

@@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

0
data/rumqttd/.gitkeep Normal file