use std::{ collections::{BTreeSet, HashMap, HashSet}, fmt::{Debug, Display}, fs::File, io::{BufReader, BufWriter}, ops::RangeInclusive, path::{Path, PathBuf}, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; use eyre::eyre; use serde::{Deserialize, Serialize, Serializer}; use tokio::{ net::TcpListener, sync::{watch::Receiver, Mutex}, task::JoinHandle, time::Instant, }; use tracing::{debug, error, info, instrument, warn}; use crate::{ constants::{CACHE_STORE_INTERVAL, PORT_OWNERSHIP_TIMEOUT, PORT_RETRY_TIME}, packets::Packet, spawn, Config, Number, Port, UnixTimestamp, }; #[derive(Default, Serialize, Deserialize)] pub struct PortHandler { #[serde(skip_deserializing)] #[serde(serialize_with = "serialize_last_update")] pub last_update: Option, #[serde(skip)] pub change_sender: Option>, #[serde(skip_deserializing)] rejectors: HashMap, allowed_ports: AllowedList, #[serde(skip)] free_ports: HashSet, errored_ports: BTreeSet<(UnixTimestamp, Port)>, allocated_ports: HashMap, pub port_state: HashMap, #[cfg(feature = "debug_server")] #[serde(default)] pub names: HashMap, } #[allow(clippy::missing_errors_doc)] pub fn serialize_last_update( last_update: &Option, serializer: S, ) -> Result { last_update .and_then(|instant| { Some( (SystemTime::now() + instant.elapsed()) .duration_since(UNIX_EPOCH) .ok()? .as_secs(), ) }) .serialize(serializer) } #[instrument(skip(port_handler, change_receiver))] pub async fn cache_daemon( port_handler: Arc>, cache_path: PathBuf, mut change_receiver: Receiver, ) { let mut last_store = Instant::now() - 2 * CACHE_STORE_INTERVAL; let mut change_timeout = None; loop { if let Some(change_timeout) = change_timeout.take() { tokio::time::timeout(change_timeout, change_receiver.changed()) .await .unwrap_or(Ok(())) } else { change_receiver.changed().await } .expect("failed to wait for cache changes"); let time_since_last_store = last_store.elapsed(); if time_since_last_store >= CACHE_STORE_INTERVAL { let port_handler = port_handler.lock().await; last_store = Instant::now(); if let Err(err) = port_handler.store(&cache_path) { error!("failed to store cache: {err:?}"); } } else { change_timeout = Some(CACHE_STORE_INTERVAL - time_since_last_store); } } } #[derive(Hash, PartialEq, Eq)] struct DisplayAsDebug(T); impl Debug for DisplayAsDebug { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } #[derive(Default, Serialize, Deserialize)] pub struct PortState { #[serde(deserialize_with = "deserialize_last_change")] last_change: UnixTimestamp, #[serde(skip_deserializing)] status: PortStatus, } fn deserialize_last_change<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { Ok(match Option::::deserialize(deserializer)? { Some(timestamp) => timestamp, None => now(), }) } fn now() -> UnixTimestamp { SystemTime::now() .duration_since(UNIX_EPOCH) .expect("timestamp overflow") .as_secs() } impl PortState { pub fn new_state(&mut self, status: PortStatus) { self.last_change = now(); self.status = status; } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, PartialOrd, Ord)] #[serde(rename_all = "snake_case")] pub enum PortStatus { InCall, Idle, Disconnected, } impl Default for PortStatus { fn default() -> Self { Self::Disconnected } } #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct AllowedList(Vec>); impl AllowedList { #[must_use] pub fn is_allowed(&self, port: Port) -> bool { self.0.iter().any(|range| range.contains(&port)) } #[must_use] pub fn is_empty(&self) -> bool { self.0.is_empty() } } impl PortHandler { pub fn register_update(&mut self) { let now = std::time::Instant::now(); self.last_update = Some(now); self.change_sender .as_ref() .expect("PortHandler is missing its change_sender") .send(now) .expect("failed to notify cache writer"); } #[allow(clippy::missing_errors_doc)] #[instrument(skip(self))] pub fn store(&self, cache: &Path) -> std::io::Result<()> { debug!("storing cache"); let temp_file = cache.with_extension("temp"); let mut value = serde_json::to_value(self)?; let value_object = value.as_object_mut().unwrap(); value_object.remove("rejectors").unwrap(); value_object.remove("last_update").unwrap(); value_object .get_mut("port_state") .unwrap() .as_object_mut() .unwrap() .values_mut() .for_each(|value| { let value_object = value.as_object_mut().unwrap(); // it does not make sense to store when the did anything else other than disconnect // because when we restart the server it will no longer be connected if value_object.get("status").unwrap().as_str().unwrap() != "disconnected" { *value_object.get_mut("last_change").unwrap() = serde_json::Value::Null; } value_object.remove("status").unwrap(); }); serde_json::to_writer(BufWriter::new(File::create(&temp_file)?), &value)?; std::fs::rename(temp_file, cache)?; Ok(()) } #[allow(clippy::missing_errors_doc)] pub fn load(cache: &Path) -> std::io::Result { info!("loading cache"); Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?) } #[must_use] #[instrument(skip(change_sender))] pub fn load_or_default( path: &Path, change_sender: tokio::sync::watch::Sender, ) -> Self { let mut this = Self::load(path).unwrap_or_else(|error| { error!(?path, %error, "failed to parse cache file"); Self::default() }); this.change_sender = Some(change_sender); this } pub fn update_allowed_ports(&mut self, allowed_ports: &AllowedList) { self.register_update(); self.allowed_ports = allowed_ports.clone(); self.free_ports.clear(); // remove all ports self.free_ports .extend(self.allowed_ports.0.iter().cloned().flatten()); // add allowed ports self.free_ports.shrink_to_fit(); // we are at the maximum number of ports we'll ever reach self.errored_ports .retain(|(_, port)| self.allowed_ports.is_allowed(*port)); // remove errored ports that are no longer allowed self.allocated_ports .retain(|_, port| self.allowed_ports.is_allowed(*port)); // remove allocated ports that are no longer allowed self.port_state .retain(|port, _| self.allowed_ports.is_allowed(*port)); // remove port states that are no longer allowed self.free_ports.retain(|port| { let is_allocted = self .allocated_ports .iter() .any(|(_, allocated_port)| allocated_port == port); let is_errored = self .errored_ports .iter() .any(|(_, errored_port)| errored_port == port); !(is_allocted || is_errored) }); } #[instrument(skip(self, listener))] pub fn start_rejector(&mut self, port: Port, listener: TcpListener, packet: Packet) { info!(port, ?packet, "starting rejector"); let port_guard = Rejector::start(listener, packet); if self.rejectors.insert(port, port_guard).is_some() { unreachable!("Tried to start rejector that is already running. This should have been impossible since it requires two listeners on the same port."); } } #[instrument(skip(self))] pub async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> { info!(port, "stopping rejector"); Some(self.rejectors.remove(&port)?.stop().await) } /// # Errors /// - the rejector must be running pub async fn change_rejector( &mut self, port: Port, f: impl FnOnce(&mut Packet), ) -> eyre::Result<()> { let (listener, mut packet) = self .stop_rejector(port) .await .ok_or_else(|| eyre!("tried to stop rejector that is not running"))?; f(&mut packet); self.start_rejector(port, listener, packet); Ok(()) } } struct Rejector { state: Arc<(Mutex, Packet)>, handle: JoinHandle<()>, } impl Serialize for Rejector { fn serialize(&self, serializer: S) -> Result where S: Serializer, { let packet = &self.state.1; match packet.as_string() { Some(string) if string.chars().all(|c| !c.is_control()) => string.serialize(serializer), _ => packet.data().serialize(serializer), } } } impl Debug for Rejector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rejector") .field("message", &self.state.1) .finish() } } impl Rejector { #[instrument(skip(listener))] fn start(listener: TcpListener, packet: Packet) -> Self { let port = listener.local_addr().map(|addr| addr.port()).unwrap_or(0); let state = Arc::new((Mutex::new(listener), packet)); let handle = { let state = state.clone(); spawn(&format!("rejector for port {port}",), async move { let (listener, packet) = state.as_ref(); let listener = listener.lock().await; loop { if let Ok((mut socket, _)) = listener.accept().await { let (_, mut writer) = socket.split(); _ = packet.send(&mut writer).await; } } }) }; Self { state, handle } } #[instrument(skip(self))] async fn stop(self) -> (TcpListener, Packet) { self.handle.abort(); _ = self.handle.await; let (listener, packet) = Arc::try_unwrap(self.state).unwrap(); (listener.into_inner(), packet) } } impl PortHandler { #[instrument(skip(self, config))] pub fn allocate_port_for_number(&mut self, config: &Config, number: Number) -> Option { let port = if let Some(port) = self.allocated_ports.get(&number) { let already_connected = self .port_state .get(port) .map_or(false, |state| state.status != PortStatus::Disconnected); if already_connected { None } else { Some(*port) } } else { let port = if let Some(&port) = self.free_ports.iter().next() { self.register_update(); self.free_ports.remove(&port); port } else { self.try_recover_port(config)? }; if self.allocated_ports.insert(number, port).is_some() { unreachable!("allocated port twice"); } Some(port) }; if let Some(port) = port { info!(port, "allocated"); } port } #[instrument(skip(self, config))] fn try_recover_port(&mut self, config: &Config) -> Option { let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); let mut recovered_port = None; self.errored_ports = std::mem::take(&mut self.errored_ports) .into_iter() .filter_map(|(mut timestamp, port)| { if recovered_port.is_none() && now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME { info!( port, last_try = ?Duration::from_secs(now.as_secs()).saturating_sub(Duration::from_secs(timestamp)), "retrying errored port", ); match std::net::TcpListener::bind((config.listen_addr.ip(), port)) { Ok(_) => { recovered_port = Some((timestamp, port)); return None; } Err(_) => timestamp = now.as_secs(), } } else { info!( port, last_try = ?Duration::from_secs(now.as_secs()).saturating_sub(Duration::from_secs(timestamp)), "skipped retrying errored port", ); } Some((timestamp, port)) }) .collect(); if let Some((_, port)) = recovered_port { self.register_update(); info!(port, "recovered port"); return Some(port); } let removable_entry = self.allocated_ports.iter().find(|(_, port)| { self.port_state.get(port).map_or(true, |port_state| { port_state.status == PortStatus::Disconnected && now.saturating_sub(Duration::from_secs(port_state.last_change)) >= PORT_OWNERSHIP_TIMEOUT }) }); if let Some((&old_number, &port)) = removable_entry { self.register_update(); info!(port, old_number, "reused port"); assert!(self.allocated_ports.remove(&old_number).is_some()); #[cfg(feature = "debug_server")] self.names.remove(&old_number); return Some(port); } None // TODO: are there more ways? } #[instrument(skip(self))] pub fn mark_port_error(&mut self, number: Number, port: Port) { warn!(port, number, "registering an error on"); self.register_update(); self.errored_ports.insert(( SystemTime::now() .duration_since(UNIX_EPOCH) .expect("timestamp overflow") .as_secs(), port, )); self.allocated_ports.remove(&number); self.free_ports.remove(&port); self.port_state.remove(&port); } }