use std::net::{IpAddr, Ipv4Addr}; use channels::account::{CreateAccount, MeResult, RegisterResult}; use channels::AsyncClient; use config::SharedAppConfig; use database_manager::Database; use futures::future::{self}; use futures::stream::StreamExt; use rumqttc::QoS; use tarpc::context; use tarpc::server::incoming::Incoming; use tarpc::server::{self, Channel}; use tarpc::tokio_serde::formats::Json; use crate::actions; #[derive(Debug, Copy, Clone, serde::Serialize, thiserror::Error)] #[serde(rename_all = "kebab-case", tag = "account")] pub enum Error { #[error("Unable to send or receive msg from database")] DbCritical, #[error("Failed to load account data")] Account, #[error("Failed to load account addresses")] Addresses, #[error("Unable to save record")] Saving, #[error("Unable to hash password")] Hashing, #[error("{0}")] Db(#[from] database_manager::Error), } #[derive(Clone)] struct AccountsServer { db: Database, config: SharedAppConfig, mqtt_client: AsyncClient, } #[tarpc::server] impl channels::account::rpc::Accounts for AccountsServer { async fn me(self, _: context::Context, account_id: model::AccountId) -> MeResult { let res = actions::me(account_id, self.db).await; tracing::info!("ME result: {:?}", res); res } async fn register_account(self, _: context::Context, details: CreateAccount) -> RegisterResult { let res = actions::create_account(details, &self.db, self.config).await; tracing::info!("REGISTER result: {:?}", res); match res { Ok(account) => { self.mqtt_client .publish_or_log( channels::account::Topic::AccountCreated, QoS::AtLeastOnce, true, &account, ) .await; RegisterResult { account: Some(account), error: None, } } Err(_e) => RegisterResult { account: None, error: Some(channels::account::Error::Account), }, } } } pub async fn start(config: SharedAppConfig, db: Database, mqtt_client: AsyncClient) { use channels::account::rpc::Accounts; let port = { config.lock().account_manager().port }; let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), port); let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default) .await .unwrap(); tracing::info!("Starting account rpc at {}", listener.local_addr()); listener.config_mut().max_frame_length(usize::MAX); listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) // Limit channels to 8 per IP. .max_channels_per_key(8, |t| t.transport().peer_addr().unwrap().ip()) .max_concurrent_requests_per_channel(20) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { channel.execute( AccountsServer { db: db.clone(), config: config.clone(), mqtt_client: mqtt_client.clone(), } .serve(), ) }) // Max 10 channels. .buffer_unordered(10) .for_each(|_| async {}) .await; tracing::info!("RPC channel closed"); }