centralex/src/ports.rs

496 lines
15 KiB
Rust

use std::{
collections::{BTreeSet, HashMap, HashSet},
fmt::{Debug, Display},
fs::File,
io::{BufReader, BufWriter},
ops::RangeInclusive,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use eyre::eyre;
use serde::{Deserialize, Serialize, Serializer};
use tokio::{
net::TcpListener,
sync::{watch::Receiver, Mutex},
task::JoinHandle,
time::Instant,
};
use tracing::{debug, error, info, instrument, warn};
use crate::{
constants::{CACHE_STORE_INTERVAL, PORT_OWNERSHIP_TIMEOUT, PORT_RETRY_TIME},
packets::Packet,
spawn, Config, Number, Port, UnixTimestamp,
};
/// Central bookkeeping for the ports handed out to numbers.
///
/// The struct is (de)serialized with serde; the per-field attributes below decide which
/// parts take part in that and which only exist at runtime.
#[derive(Default, Serialize, Deserialize)]
pub struct PortHandler {
/// Moment of the most recent `register_update` call. Serialized through
/// `serialize_last_update`, never read back when deserializing.
#[serde(skip_deserializing)]
#[serde(serialize_with = "serialize_last_update")]
pub last_update: Option<std::time::Instant>,
/// Watch channel that `register_update` sends to; skipped by serde entirely and
/// attached in `load_or_default`.
#[serde(skip)]
pub change_sender: Option<tokio::sync::watch::Sender<std::time::Instant>>,
/// Running rejectors keyed by port. Serialized, but never deserialized.
#[serde(skip_deserializing)]
rejectors: HashMap<Port, Rejector>,
/// Port ranges that may currently be handed out (see `update_allowed_ports`).
allowed_ports: AllowedList,
/// Allowed ports that are neither allocated nor errored. Skipped by serde and rebuilt
/// by `update_allowed_ports`.
#[serde(skip)]
free_ports: HashSet<Port>,
/// Ports that were reported through `mark_port_error`, paired with the Unix timestamp
/// of the last failed attempt. Ordered by that timestamp first.
errored_ports: BTreeSet<(UnixTimestamp, Port)>,
/// Which port each number currently owns.
allocated_ports: HashMap<Number, Port>,
/// Last known state of each port.
pub port_state: HashMap<Port, PortState>,
/// Names keyed by number; only compiled with the `debug_server` feature and defaulted
/// when missing from the input. NOTE(review): presumably display names for the debug
/// server - confirm against the code that fills this map.
#[cfg(feature = "debug_server")]
#[serde(default)]
pub names: HashMap<Number, String>,
}
/// Serializes `last_update` as whole seconds since the Unix epoch, or as a missing value.
///
/// An `Instant` carries no wall-clock meaning, so the wall-clock time of the update is
/// reconstructed as "now minus the time that has elapsed since the instant". If that
/// point in time cannot be represented, or lies before the Unix epoch, a missing value is
/// serialized instead.
#[allow(clippy::missing_errors_doc)]
pub fn serialize_last_update<S: Serializer>(
    last_update: &Option<std::time::Instant>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    last_update
        .and_then(|instant| {
            // The update happened `elapsed` ago, so it has to be subtracted from the
            // current time (adding it would yield a timestamp in the future).
            let updated_at = SystemTime::now().checked_sub(instant.elapsed())?;
            Some(updated_at.duration_since(UNIX_EPOCH).ok()?.as_secs())
        })
        .serialize(serializer)
}
/// Background task that writes the `PortHandler` to `cache_path` whenever a change is
/// signalled, but at most once per `CACHE_STORE_INTERVAL`.
///
/// A change that arrives too early is not dropped: the loop remembers how long it still
/// has to wait and stores the cache once that time is over, even without a new signal.
///
/// # Panics
/// If waiting on `change_receiver` fails.
#[instrument(skip(port_handler, change_receiver))]
pub async fn cache_daemon(
    port_handler: Arc<Mutex<PortHandler>>,
    cache_path: PathBuf,
    mut change_receiver: Receiver<std::time::Instant>,
) {
    // Backdated so that the very first change is stored right away.
    let mut last_store = Instant::now() - 2 * CACHE_STORE_INTERVAL;
    // Remaining wait time when a change arrived before the interval was over.
    let mut pending_wait = None;

    loop {
        let wait_result = match pending_wait.take() {
            // A store is already due later: wake up on the next change or when the
            // remaining time is over, whichever comes first.
            Some(remaining) => tokio::time::timeout(remaining, change_receiver.changed())
                .await
                .unwrap_or(Ok(())),
            None => change_receiver.changed().await,
        };
        wait_result.expect("failed to wait for cache changes");

        let since_last_store = last_store.elapsed();
        if since_last_store < CACHE_STORE_INTERVAL {
            pending_wait = Some(CACHE_STORE_INTERVAL - since_last_store);
            continue;
        }

        let handler = port_handler.lock().await;
        last_store = Instant::now();
        if let Err(err) = handler.store(&cache_path) {
            error!("failed to store cache: {err:?}");
        }
    }
}
/// Wrapper whose `Debug` output is the wrapped value's `Display` output.
#[derive(Hash, PartialEq, Eq)]
struct DisplayAsDebug<T: Display>(T);

impl<T> Debug for DisplayAsDebug<T>
where
    T: Display,
{
    /// Writes the inner value with a plain `{}`; formatting flags given to the `Debug`
    /// formatter are not forwarded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        f.write_fmt(format_args!("{}", inner))
    }
}
/// State of a single port together with the time it last changed.
#[derive(Default, Serialize, Deserialize)]
pub struct PortState {
/// Unix timestamp (seconds) of the last status change. Read through
/// `deserialize_last_change`, which substitutes the current time for a `null` value.
#[serde(deserialize_with = "deserialize_last_change")]
last_change: UnixTimestamp,
/// Current status. Never deserialized, so a loaded state starts with the default status.
#[serde(skip_deserializing)]
status: PortStatus,
}
fn deserialize_last_change<'de, D>(deserializer: D) -> Result<UnixTimestamp, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(match Option::<UnixTimestamp>::deserialize(deserializer)? {
Some(timestamp) => timestamp,
None => now(),
})
}
fn now() -> UnixTimestamp {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("timestamp overflow")
.as_secs()
}
impl PortState {
    /// Switches to `status` and records the current time as the moment of the change.
    pub fn new_state(&mut self, status: PortStatus) {
        let changed_at = now();
        self.status = status;
        self.last_change = changed_at;
    }
}
/// Connection status of a port.
///
/// Serialized in `snake_case` (`in_call`, `idle`, `disconnected`). The derived ordering
/// follows the declaration order of the variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, PartialOrd, Ord)]
#[serde(rename_all = "snake_case")]
pub enum PortStatus {
InCall,
Idle,
Disconnected,
}
impl Default for PortStatus {
fn default() -> Self {
Self::Disconnected
}
}
/// List of inclusive port ranges that may be handed out.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct AllowedList(Vec<RangeInclusive<u16>>);
impl AllowedList {
#[must_use]
pub fn is_allowed(&self, port: Port) -> bool {
self.0.iter().any(|range| range.contains(&port))
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
impl PortHandler {
/// Records that the handler changed just now and notifies the change channel.
///
/// # Panics
/// If no `change_sender` is attached or the notification cannot be sent.
pub fn register_update(&mut self) {
    let timestamp = std::time::Instant::now();
    self.last_update = Some(timestamp);

    let sender = self
        .change_sender
        .as_ref()
        .expect("PortHandler is missing its change_sender");
    sender
        .send(timestamp)
        .expect("failed to notify cache writer");
}
/// Writes the persistent part of the handler to `cache` as JSON.
///
/// The data is first written to a sibling file with the extension `temp` and then renamed
/// over `cache`, so `cache` is only replaced after the new contents were written and
/// flushed successfully.
///
/// # Errors
/// Returns an error if serialization, creating/writing/flushing the temporary file, or
/// the final rename fails.
///
/// # Panics
/// If the serialized handler does not have the expected JSON shape (an object containing
/// `rejectors`, `last_update` and a `port_state` object whose entries carry `status` and
/// `last_change`).
#[allow(clippy::missing_errors_doc)]
#[instrument(skip(self))]
pub fn store(&self, cache: &Path) -> std::io::Result<()> {
    debug!("storing cache");
    let temp_file = cache.with_extension("temp");

    let mut value = serde_json::to_value(self)?;
    let value_object = value.as_object_mut().unwrap();

    // These are part of the serialized form but are never read back from the cache.
    value_object.remove("rejectors").unwrap();
    value_object.remove("last_update").unwrap();

    value_object
        .get_mut("port_state")
        .unwrap()
        .as_object_mut()
        .unwrap()
        .values_mut()
        .for_each(|value| {
            let value_object = value.as_object_mut().unwrap();
            // It does not make sense to store when a port did anything other than
            // disconnect, because after a server restart it will no longer be connected.
            if value_object.get("status").unwrap().as_str().unwrap() != "disconnected" {
                *value_object.get_mut("last_change").unwrap() = serde_json::Value::Null;
            }
            value_object.remove("status").unwrap();
        });

    let mut writer = BufWriter::new(File::create(&temp_file)?);
    serde_json::to_writer(&mut writer, &value)?;
    // Flush explicitly: dropping a `BufWriter` discards write errors, which would let a
    // truncated temporary file replace the previous cache below.
    std::io::Write::flush(&mut writer)?;
    // Close the temporary file before renaming it.
    drop(writer);

    std::fs::rename(temp_file, cache)?;
    Ok(())
}
/// Reads a handler from the JSON file at `cache`.
///
/// Fields that serde skips while deserializing start out with their default values.
///
/// # Errors
/// Returns an error if the file cannot be opened or its contents cannot be parsed.
#[allow(clippy::missing_errors_doc)]
pub fn load(cache: &Path) -> std::io::Result<Self> {
    info!("loading cache");
    let reader = BufReader::new(File::open(cache)?);
    let handler = serde_json::from_reader(reader)?;
    Ok(handler)
}
/// Loads the handler from `path`, falling back to an empty default handler (and logging
/// the error) when loading fails. In both cases `change_sender` is attached to the result.
#[must_use]
#[instrument(skip(change_sender))]
pub fn load_or_default(
    path: &Path,
    change_sender: tokio::sync::watch::Sender<std::time::Instant>,
) -> Self {
    let mut handler = match Self::load(path) {
        Ok(loaded) => loaded,
        Err(error) => {
            error!(?path, %error, "failed to parse cache file");
            Self::default()
        }
    };
    handler.change_sender = Some(change_sender);
    handler
}
/// Replaces the list of allowed ports and brings all bookkeeping in line with it.
///
/// Errored ports, allocations and port states that refer to a port outside the new list
/// are dropped. Afterwards the free ports are exactly the allowed ports that are neither
/// allocated nor errored.
///
/// # Panics
/// If the update cannot be registered (see `register_update`).
pub fn update_allowed_ports(&mut self, allowed_ports: &AllowedList) {
    self.register_update();
    self.allowed_ports = allowed_ports.clone();

    // Start from every allowed port ...
    self.free_ports.clear();
    self.free_ports
        .extend(allowed_ports.0.iter().cloned().flatten());
    self.free_ports.shrink_to_fit(); // we are at the maximum number of ports we'll ever reach

    // ... forget everything that refers to a port that is no longer allowed ...
    self.errored_ports
        .retain(|(_, port)| allowed_ports.is_allowed(*port));
    self.allocated_ports
        .retain(|_, port| allowed_ports.is_allowed(*port));
    self.port_state
        .retain(|port, _| allowed_ports.is_allowed(*port));

    // ... and take the ports that are in use out of the free set. Removing them directly
    // avoids scanning both collections once per free port.
    for port in self.allocated_ports.values() {
        self.free_ports.remove(port);
    }
    for (_, port) in &self.errored_ports {
        self.free_ports.remove(port);
    }
}
/// Starts a rejector on `listener` that answers connections with `packet` and registers
/// it under `port`.
///
/// # Panics
/// If a rejector is already registered for `port`.
#[instrument(skip(self, listener))]
pub fn start_rejector(&mut self, port: Port, listener: TcpListener, packet: Packet) {
    info!(port, ?packet, "starting rejector");

    let rejector = Rejector::start(listener, packet);
    let previous = self.rejectors.insert(port, rejector);
    if previous.is_some() {
        unreachable!("Tried to start rejector that is already running. This should have been impossible since it requires two listeners on the same port.");
    }
}
/// Stops the rejector registered for `port` and hands back its listener and packet.
///
/// Returns `None` when no rejector is registered for `port`.
#[instrument(skip(self))]
pub async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> {
    info!(port, "stopping rejector");

    match self.rejectors.remove(&port) {
        Some(rejector) => Some(rejector.stop().await),
        None => None,
    }
}
/// Applies `f` to the packet of the rejector running on `port` by stopping the rejector,
/// modifying the packet and starting it again on the same listener.
///
/// # Errors
/// - the rejector must be running
pub async fn change_rejector(
    &mut self,
    port: Port,
    f: impl FnOnce(&mut Packet),
) -> eyre::Result<()> {
    match self.stop_rejector(port).await {
        Some((listener, mut packet)) => {
            f(&mut packet);
            self.start_rejector(port, listener, packet);
            Ok(())
        }
        None => Err(eyre!("tried to stop rejector that is not running")),
    }
}
}
/// A background task that accepts connections on a listener and answers each of them
/// with a fixed packet.
struct Rejector {
/// The listener (behind a mutex) and the packet, shared with the spawned task.
state: Arc<(Mutex<TcpListener>, Packet)>,
/// Handle of the spawned task; aborted and awaited in `stop`.
handle: JoinHandle<()>,
}
impl Serialize for Rejector {
    /// Serializes the rejector as its packet: as a string when the packet has a string
    /// form without control characters, otherwise as the packet's raw data.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let (_, packet) = &*self.state;
        if let Some(text) = packet.as_string() {
            if !text.chars().any(|c| c.is_control()) {
                return text.serialize(serializer);
            }
        }
        packet.data().serialize(serializer)
    }
}
impl Debug for Rejector {
    /// Shows only the packet (as `message`); listener and task handle are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (_, message) = &*self.state;
        let mut output = f.debug_struct("Rejector");
        output.field("message", message);
        output.finish()
    }
}
impl Rejector {
#[instrument(skip(listener))]
fn start(listener: TcpListener, packet: Packet) -> Self {
let port = listener.local_addr().map(|addr| addr.port()).unwrap_or(0);
let state = Arc::new((Mutex::new(listener), packet));
let handle = {
let state = state.clone();
spawn(&format!("rejector for port {port}",), async move {
let (listener, packet) = state.as_ref();
let listener = listener.lock().await;
loop {
if let Ok((mut socket, _)) = listener.accept().await {
let (_, mut writer) = socket.split();
_ = packet.send(&mut writer).await;
}
}
})
};
Self { state, handle }
}
#[instrument(skip(self))]
async fn stop(self) -> (TcpListener, Packet) {
self.handle.abort();
_ = self.handle.await;
let (listener, packet) = Arc::try_unwrap(self.state).unwrap();
(listener.into_inner(), packet)
}
}
impl PortHandler {
/// Returns the port `number` should use, allocating one if it does not own a port yet.
///
/// - A number that already owns a port gets that port back, unless the port's state is
///   anything other than `Disconnected`, in which case `None` is returned.
/// - Otherwise a free port is taken; if none is free, `try_recover_port` is asked for
///   one. `None` is returned when that fails as well.
#[instrument(skip(self, config))]
pub fn allocate_port_for_number(&mut self, config: &Config, number: Number) -> Option<Port> {
    if let Some(&port) = self.allocated_ports.get(&number) {
        let already_connected = matches!(
            self.port_state.get(&port),
            Some(state) if state.status != PortStatus::Disconnected
        );
        if already_connected {
            return None;
        }
        info!(port, "allocated");
        return Some(port);
    }

    let port = match self.free_ports.iter().next().copied() {
        Some(port) => {
            self.register_update();
            self.free_ports.remove(&port);
            port
        }
        None => self.try_recover_port(config)?,
    };

    if self.allocated_ports.insert(number, port).is_some() {
        unreachable!("allocated port twice");
    }
    info!(port, "allocated");
    Some(port)
}
#[instrument(skip(self, config))]
fn try_recover_port(&mut self, config: &Config) -> Option<Port> {
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let mut recovered_port = None;
self.errored_ports = std::mem::take(&mut self.errored_ports)
.into_iter()
.filter_map(|(mut timestamp, port)| {
if recovered_port.is_none()
&& now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME
{
info!(
port,
last_try = ?Duration::from_secs(now.as_secs()).saturating_sub(Duration::from_secs(timestamp)),
"retrying errored port",
);
match std::net::TcpListener::bind((config.listen_addr.ip(), port)) {
Ok(_) => {
recovered_port = Some((timestamp, port));
return None;
}
Err(_) => timestamp = now.as_secs(),
}
} else {
info!(
port,
last_try = ?Duration::from_secs(now.as_secs()).saturating_sub(Duration::from_secs(timestamp)),
"skipped retrying errored port",
);
}
Some((timestamp, port))
})
.collect();
if let Some((_, port)) = recovered_port {
self.register_update();
info!(port, "recovered port");
return Some(port);
}
let removable_entry = self.allocated_ports.iter().find(|(_, port)| {
self.port_state.get(port).map_or(true, |port_state| {
port_state.status == PortStatus::Disconnected
&& now.saturating_sub(Duration::from_secs(port_state.last_change))
>= PORT_OWNERSHIP_TIMEOUT
})
});
if let Some((&old_number, &port)) = removable_entry {
self.register_update();
info!(port, old_number, "reused port");
assert!(self.allocated_ports.remove(&old_number).is_some());
#[cfg(feature = "debug_server")]
self.names.remove(&old_number);
return Some(port);
}
None // TODO: are there more ways?
}
#[instrument(skip(self))]
pub fn mark_port_error(&mut self, number: Number, port: Port) {
warn!(port, number, "registering an error on");
self.register_update();
self.errored_ports.insert((
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("timestamp overflow")
.as_secs(),
port,
));
self.allocated_ports.remove(&number);
self.free_ports.remove(&port);
self.port_state.remove(&port);
}
}