use std::net::{IpAddr, Ipv4Addr}; use channels::accounts::rpc::Accounts; use channels::accounts::{me, register}; use channels::AsyncClient; use config::SharedAppConfig; use futures::future::{self}; use futures::stream::StreamExt; use rumqttc::QoS; use tarpc::context; use tarpc::server::incoming::Incoming; use tarpc::server::{self, Channel}; use tarpc::tokio_serde::formats::Bincode; use crate::actions; use crate::db::Database; #[derive(Debug, Copy, Clone, serde::Serialize, thiserror::Error)] #[serde(rename_all = "kebab-case", tag = "account")] pub enum Error { #[error("Unable to send or receive msg from database")] DbCritical, #[error("Failed to load account data")] Account, #[error("Failed to load account addresses")] Addresses, #[error("Unable to save record")] Saving, #[error("Unable to hash password")] Hashing, } #[derive(Clone)] struct AccountsServer { db: Database, config: SharedAppConfig, mqtt_client: AsyncClient, } #[tarpc::server] impl Accounts for AccountsServer { async fn me(self, _: context::Context, input: me::Input) -> me::Output { let res = actions::me(input.account_id, self.db).await; tracing::info!("ME result: {:?}", res); res } async fn register_account( self, _: context::Context, input: register::Input, ) -> register::Output { use channels::accounts::{Error, Topic}; let res = actions::create_account(input, &self.db, self.config).await; tracing::info!("REGISTER result: {:?}", res); match res { Ok(account) => { self.mqtt_client .publish_or_log(Topic::AccountCreated, QoS::AtLeastOnce, true, &account) .await; register::Output { account: Some(account), error: None, } } Err(_e) => register::Output { account: None, error: Some(Error::Account), }, } } } pub async fn start(config: SharedAppConfig, db: Database, mqtt_client: AsyncClient) { let port = { config.lock().account_manager().port }; let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), port); let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Bincode::default) .await .unwrap(); tracing::info!("Starting account rpc at {}", listener.local_addr()); listener.config_mut().max_frame_length(usize::MAX); listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) // Limit channels to 8 per IP. .max_channels_per_key(8, |t| t.transport().peer_addr().unwrap().ip()) .max_concurrent_requests_per_channel(20) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { channel.execute( AccountsServer { db: db.clone(), config: config.clone(), mqtt_client: mqtt_client.clone(), } .serve(), ) }) // Max 10 channels. .buffer_unordered(10) .for_each(|_| async {}) .await; tracing::info!("RPC channel closed"); }