From 4870dafbaaff2c75f63723c04f73131fdeb4f6e4 Mon Sep 17 00:00:00 2001 From: soruh Date: Sat, 18 Feb 2023 18:17:59 +0100 Subject: [PATCH] restructure code --- Cargo.toml | 6 +- src/debug_server.rs | 1 + src/main.rs | 466 +++++++++----------------------------------- src/ports.rs | 318 ++++++++++++++++++++++++++++++ 4 files changed, 411 insertions(+), 380 deletions(-) create mode 100644 src/debug_server.rs create mode 100644 src/ports.rs diff --git a/Cargo.toml b/Cargo.toml index b20ed59..0ccf29a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,12 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread", "net", "io-util", "sync", "time"] } anyhow = { version = "1.0.68", features = ["backtrace"] } bytemuck = { version = "1.13.0", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.91" -tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread", "net", "io-util", "sync", "time"] } + +[features] +default = ["debug_server"] +debug_server = [] \ No newline at end of file diff --git a/src/debug_server.rs b/src/debug_server.rs new file mode 100644 index 0000000..03ecf07 --- /dev/null +++ b/src/debug_server.rs @@ -0,0 +1 @@ +pub async fn debug_server() {} diff --git a/src/main.rs b/src/main.rs index e69e3ce..20ea299 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,29 @@ // #![allow(unused)] use std::{ - collections::{BTreeSet, HashMap, HashSet}, fmt::Debug, fs::File, - io::{BufReader, BufWriter}, + io::BufReader, net::{SocketAddr, ToSocketAddrs}, - ops::Range, path::{Path, PathBuf}, sync::Arc, - time::{Duration, SystemTime, UNIX_EPOCH}, + time::Duration, }; -use anyhow::{anyhow, bail, Context}; +use anyhow::{bail, Context}; +use debug_server::debug_server; use packets::{Header, Packet, RemConnect}; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Deserializer}; use tokio::{ io::AsyncWriteExt, net::{TcpListener, TcpStream}, select, sync::Mutex, - task::JoinHandle, time::{sleep, Instant}, }; use crate::packets::{dyn_ip_update, PacketKind, REJECT_OOP, REJECT_TIMEOUT}; +use crate::ports::{AllowedPorts, PortHandler, PortStatus}; const AUTH_TIMEOUT: Duration = Duration::from_secs(30); const CALL_ACK_TIMEOUT: Duration = Duration::from_secs(30); @@ -34,14 +33,17 @@ const PORT_OWNERSHIP_TIMEOUT: Duration = Duration::from_secs(1 * 60 * 60); const PING_TIMEOUT: Duration = Duration::from_secs(30); const SEND_PING_INTERVAL: Duration = Duration::from_secs(20); +#[cfg(feature = "debug_server")] +mod debug_server; mod packets; +mod ports; type Port = u16; type Number = u32; type UnixTimestamp = u64; #[derive(Debug, Deserialize)] -struct Config { +pub struct Config { allowed_ports: AllowedPorts, #[serde(deserialize_with = "parse_socket_addr")] listen_addr: SocketAddr, @@ -68,301 +70,113 @@ impl Config { } } -#[derive(Default, Debug, Serialize, Deserialize)] -struct PortHandler { - #[serde(skip)] - last_update: Option, +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let config = Arc::new(Config::load("config.json")?); - #[serde(skip)] - port_guards: HashMap, - - allowed_ports: AllowedPorts, - - free_ports: HashSet, - errored_ports: BTreeSet<(UnixTimestamp, Port)>, - allocated_ports: HashMap, - - #[serde(skip)] - port_state: HashMap, -} - -#[derive(Default, Debug)] -struct PortState { - last_change: UnixTimestamp, - status: PortStatus, -} - -impl PortState { - fn new_state(&mut self, status: PortStatus) { - self.last_change = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - self.status = status; - } -} - -#[derive(Debug, PartialEq, Eq)] -enum PortStatus { - Disconnected, - Idle, - InCall, -} - -impl Default for PortStatus { - fn default() -> Self { - Self::Disconnected - } -} - -#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -struct AllowedPorts(Vec>); - -impl AllowedPorts { - fn is_allowed(&self, port: Port) -> bool { - self.0.iter().any(|range| range.contains(&port)) - } -} - -impl PortHandler { - fn register_update(&mut self) { - self.last_update = Some(Instant::now()); + if config.allowed_ports.is_empty() { + panic!("no allowed ports"); } - fn store(&self, cache: &Path) -> anyhow::Result<()> { - println!("storing database"); - serde_json::to_writer(BufWriter::new(File::create(cache)?), self)?; - Ok(()) - } + let cache_path = PathBuf::from("cache.json"); - fn load(cache: &Path) -> std::io::Result { - println!("loading database"); - Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?) - } + let mut port_handler = PortHandler::load_or_default(&cache_path); + port_handler.update_allowed_ports(&config.allowed_ports); - fn load_or_default(cache: &Path) -> Self { - Self::load(cache).unwrap_or_else(|err| { - println!("failed to parse cache file at {cache:?} using empty cache. error: {err}"); - Self::default() - }) - } + let port_handler = Arc::new(Mutex::new(port_handler)); - fn update_allowed_ports(&mut self, allowed_ports: &AllowedPorts) { - self.register_update(); + { + let port_handler = port_handler.clone(); + tokio::spawn(async move { + let mut last_store = None; + loop { + sleep(Duration::from_secs(1)).await; - self.allowed_ports = allowed_ports.clone(); + let port_handler = port_handler.lock().await; - self.free_ports.clear(); - self.free_ports - .extend(self.allowed_ports.0.iter().cloned().flatten()); + if let Some(last_update) = port_handler.last_update { + let should_store = last_store + .map(|last_store| last_update > last_store) + .unwrap_or(true); - self.free_ports.shrink_to_fit(); // we are at the maximum number of ports we'll ever reach - - self.errored_ports - .retain(|(_, port)| self.allowed_ports.is_allowed(*port)); - - self.allocated_ports - .retain(|_, port| self.allowed_ports.is_allowed(*port)); - - self.free_ports.retain(|port| { - self.allocated_ports - .iter() - .find(|(_, allocated_port)| *allocated_port == port) - .is_none() - && self - .errored_ports - .iter() - .find(|(_, errored_port)| errored_port == port) - .is_none() + if should_store { + last_store = Some(last_update); + port_handler.store(&cache_path).unwrap(); + } + } + } }); } - fn start_rejector( - &mut self, - port: Port, - listener: TcpListener, - packet: Packet, - ) -> anyhow::Result<()> { - println!("starting rejector: for port {port} with {packet:?}"); + #[cfg(feature = "debug_server")] + tokio::spawn(debug_server()); - let port_guard = PortGuard::start(listener, packet); + let listener = TcpListener::bind(config.listen_addr).await?; - assert!( - self.port_guards.insert(port, port_guard).is_none(), - "Tried to start rejector that is already running. - This should have been impossible since it requires two listeners on the same port." - ); - Ok(()) - } + while let Ok((mut stream, addr)) = listener.accept().await { + println!("connection from {addr}"); - async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> { - println!("stopping rejector: for port {port}"); + let port_handler = port_handler.clone(); + let config = config.clone(); - Some(self.port_guards.remove(&port)?.stop().await) - } + let mut handler_metadata = HandlerMetadata::default(); - async fn change_rejector( - &mut self, - port: Port, - f: impl FnOnce(&mut Packet), - ) -> anyhow::Result<()> { - let (listener, mut packet) = self - .stop_rejector(port) - .await - .ok_or_else(|| anyhow!("tried to stop rejector that is not running"))?; + tokio::spawn(async move { + let res = + connection_handler(&config, &mut handler_metadata, &port_handler, &mut stream) + .await; - f(&mut packet); + if let Err(err) = res { + println!("client at {addr} had an error: {err}"); - self.start_rejector(port, listener, packet) - } -} + let mut packet = Packet::default(); -struct PortGuard { - state: Arc<(Mutex, Packet)>, - handle: JoinHandle<()>, -} + packet.data.extend_from_slice(err.to_string().as_bytes()); + packet.data.truncate(0xfe); + packet.data.push(0); + packet.header = Header { + kind: PacketKind::Error.raw(), + length: packet.data.len() as u8, + }; -impl Debug for PortGuard { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PortGuard").finish() - } -} + let (_, mut writer) = stream.split(); + let _ = packet.send(&mut writer).await; + } -impl PortGuard { - fn start(listener: TcpListener, packet: Packet) -> Self { - let state = Arc::new((Mutex::new(listener), packet)); + if let Some(port) = handler_metadata.port { + let mut port_handler = port_handler.lock().await; - let handle = { - let state = state.clone(); - - tokio::spawn(async move { - let (listener, packet) = state.as_ref(); - - let listener = listener.lock().await; - - loop { - if let Ok((mut socket, _)) = listener.accept().await { - let (_, mut writer) = socket.split(); - let _ = packet.send(&mut writer).await; - } - } - }) - }; - - Self { state, handle } - } - async fn stop(self) -> (TcpListener, Packet) { - self.handle.abort(); - let _ = self.handle.await; - let (listener, packet) = Arc::try_unwrap(self.state).unwrap(); - (listener.into_inner(), packet) - } -} - -impl PortHandler { - fn allocate_port_for_number(&mut self, config: &Config, number: Number) -> Option { - if let Some(port) = self.allocated_ports.get(&number) { - let already_connected = self - .port_state - .get(port) - .map(|state| state.status != PortStatus::Disconnected) - .unwrap_or(false); - - return if already_connected { None } else { Some(*port) }; - } - - let port = if let Some(&port) = self.free_ports.iter().next() { - self.register_update(); - self.free_ports.remove(&port); - port - } else { - self.try_recover_port(config)? - }; - - assert!(self.allocated_ports.insert(number, port).is_none()); - Some(port) - } - - fn try_recover_port(&mut self, config: &Config) -> Option { - let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); - - let mut recovered_port = None; - - self.errored_ports = std::mem::take(&mut self.errored_ports) - .into_iter() - .filter_map(|(mut timestamp, port)| { - if recovered_port.is_none() - && now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME - { - println!( - " trying port: {port} at -{:?}", - Duration::from_secs(now.as_secs()) - .saturating_sub(Duration::from_secs(timestamp)) - ); - - match std::net::TcpListener::bind((config.listen_addr.ip(), port)) { - Ok(_) => { - recovered_port = Some((timestamp, port)); - return None; - } - Err(_) => timestamp = now.as_secs(), - } - } else { - println!( - "skipped port: {port} at -{:?}", - Duration::from_secs(now.as_secs()) - .saturating_sub(Duration::from_secs(timestamp)) - ); + if let Some(port_state) = port_handler.port_state.get_mut(&port) { + port_state.new_state(PortStatus::Disconnected); } - Some((timestamp, port)) - }) - .collect(); + if let Some(listener) = handler_metadata.listener.take() { + let res = port_handler.start_rejector( + port, + listener, + Packet { + header: Header { + kind: PacketKind::Reject.raw(), + length: 3, + }, + data: b"nc\0".to_vec(), + }, + ); - if let Some((_, port)) = recovered_port { - self.register_update(); - println!("recovered_port: {port}"); - return Some(port); - } + if let Err(err) = res { + println!( + "failed to start rejector on port {port} after client error: {err}" + ); + } + } + } - let removable_entry = self.allocated_ports.iter().find(|(_, port)| { - self.port_state - .get(port) - .map(|port_state| { - dbg!(port_state).status == PortStatus::Disconnected - && dbg!(now.saturating_sub(Duration::from_secs(port_state.last_change))) - >= PORT_OWNERSHIP_TIMEOUT - }) - .unwrap_or(true) + sleep(Duration::from_secs(3)).await; + let _ = stream.shutdown().await; }); - - dbg!(&removable_entry); - - if let Some((&old_number, &port)) = removable_entry { - self.register_update(); - println!("reused port {port} which used to be allocated to {old_number} which wasn't connected in a long time"); - assert!(self.allocated_ports.remove(&old_number).is_some()); - return Some(port); - } - - None // TODO } - fn mark_port_error(&mut self, number: Number, port: Port) { - self.register_update(); - - self.errored_ports.insert(( - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(), - port, - )); - - self.allocated_ports.remove(&number); - self.free_ports.remove(&port); - } + Ok(()) } #[derive(Debug, Default)] @@ -649,109 +463,3 @@ async fn connection_handler( kind => bail!("unexpected packet: {:?}", kind), } } - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let config = Arc::new(Config::load("config.json")?); - - if config.allowed_ports.0.is_empty() { - panic!("no allowed ports"); - } - - let cache_path = PathBuf::from("cache.json"); - - let mut port_handler = PortHandler::load_or_default(&cache_path); - port_handler.update_allowed_ports(&config.allowed_ports); - - let port_handler = Arc::new(Mutex::new(port_handler)); - - { - let port_handler = port_handler.clone(); - tokio::spawn(async move { - let mut last_store = None; - loop { - sleep(Duration::from_secs(1)).await; - - let port_handler = port_handler.lock().await; - - if let Some(last_update) = port_handler.last_update { - let should_store = last_store - .map(|last_store| last_update > last_store) - .unwrap_or(true); - - if should_store { - last_store = Some(last_update); - port_handler.store(&cache_path).unwrap(); - } - } - } - }); - } - - let listener = TcpListener::bind(config.listen_addr).await?; - - while let Ok((mut stream, addr)) = listener.accept().await { - println!("connection from {addr}"); - - let port_handler = port_handler.clone(); - let config = config.clone(); - - let mut handler_metadata = HandlerMetadata::default(); - - tokio::spawn(async move { - let res = - connection_handler(&config, &mut handler_metadata, &port_handler, &mut stream) - .await; - - if let Err(err) = res { - println!("client at {addr} had an error: {err}"); - - let mut packet = Packet::default(); - - packet.data.extend_from_slice(err.to_string().as_bytes()); - packet.data.truncate(0xfe); - packet.data.push(0); - packet.header = Header { - kind: PacketKind::Error.raw(), - length: packet.data.len() as u8, - }; - - let (_, mut writer) = stream.split(); - let _ = packet.send(&mut writer).await; - } - - if let Some(port) = handler_metadata.port { - let mut port_handler = port_handler.lock().await; - - if let Some(port_state) = port_handler.port_state.get_mut(&port) { - port_state.new_state(PortStatus::Disconnected); - } - - if let Some(listener) = handler_metadata.listener.take() { - let res = port_handler.start_rejector( - port, - listener, - Packet { - header: Header { - kind: PacketKind::Reject.raw(), - length: 3, - }, - data: b"nc\0".to_vec(), - }, - ); - - if let Err(err) = res { - println!( - "failed to start rejector on port {port} after client error: {err}" - ); - } - } - } - - sleep(Duration::from_secs(3)).await; - let _ = stream.shutdown().await; - }); - } - - Ok(()) -} diff --git a/src/ports.rs b/src/ports.rs new file mode 100644 index 0000000..0029a57 --- /dev/null +++ b/src/ports.rs @@ -0,0 +1,318 @@ +use std::{ + collections::{BTreeSet, HashMap, HashSet}, + fmt::Debug, + fs::File, + io::{BufReader, BufWriter}, + ops::Range, + path::Path, + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use anyhow::anyhow; +use serde::{Deserialize, Serialize}; +use tokio::{net::TcpListener, sync::Mutex, task::JoinHandle, time::Instant}; + +use crate::{ + packets::Packet, Config, Number, Port, UnixTimestamp, PORT_OWNERSHIP_TIMEOUT, PORT_RETRY_TIME, +}; + +#[derive(Default, Debug, Serialize, Deserialize)] +pub struct PortHandler { + #[serde(skip)] + pub last_update: Option, + + #[serde(skip)] + port_guards: HashMap, + + allowed_ports: AllowedPorts, + + free_ports: HashSet, + errored_ports: BTreeSet<(UnixTimestamp, Port)>, + allocated_ports: HashMap, + + #[serde(skip)] + pub port_state: HashMap, +} + +#[derive(Default, Debug)] +pub struct PortState { + last_change: UnixTimestamp, + status: PortStatus, +} + +impl PortState { + pub fn new_state(&mut self, status: PortStatus) { + self.last_change = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + self.status = status; + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum PortStatus { + Disconnected, + Idle, + InCall, +} + +impl Default for PortStatus { + fn default() -> Self { + Self::Disconnected + } +} + +#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +pub struct AllowedPorts(Vec>); + +impl AllowedPorts { + pub fn is_allowed(&self, port: Port) -> bool { + self.0.iter().any(|range| range.contains(&port)) + } + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl PortHandler { + pub fn register_update(&mut self) { + self.last_update = Some(Instant::now()); + } + + pub fn store(&self, cache: &Path) -> anyhow::Result<()> { + println!("storing database"); + serde_json::to_writer(BufWriter::new(File::create(cache)?), self)?; + Ok(()) + } + + pub fn load(cache: &Path) -> std::io::Result { + println!("loading database"); + Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?) + } + + pub fn load_or_default(cache: &Path) -> Self { + Self::load(cache).unwrap_or_else(|err| { + println!("failed to parse cache file at {cache:?} using empty cache. error: {err}"); + Self::default() + }) + } + + pub fn update_allowed_ports(&mut self, allowed_ports: &AllowedPorts) { + self.register_update(); + + self.allowed_ports = allowed_ports.clone(); + + self.free_ports.clear(); + self.free_ports + .extend(self.allowed_ports.0.iter().cloned().flatten()); + + self.free_ports.shrink_to_fit(); // we are at the maximum number of ports we'll ever reach + + self.errored_ports + .retain(|(_, port)| self.allowed_ports.is_allowed(*port)); + + self.allocated_ports + .retain(|_, port| self.allowed_ports.is_allowed(*port)); + + self.free_ports.retain(|port| { + self.allocated_ports + .iter() + .find(|(_, allocated_port)| *allocated_port == port) + .is_none() + && self + .errored_ports + .iter() + .find(|(_, errored_port)| errored_port == port) + .is_none() + }); + } + + pub fn start_rejector( + &mut self, + port: Port, + listener: TcpListener, + packet: Packet, + ) -> anyhow::Result<()> { + println!("starting rejector: for port {port} with {packet:?}"); + + let port_guard = Rejector::start(listener, packet); + + assert!( + self.port_guards.insert(port, port_guard).is_none(), + "Tried to start rejector that is already running. + This should have been impossible since it requires two listeners on the same port." + ); + Ok(()) + } + + pub async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> { + println!("stopping rejector: for port {port}"); + + Some(self.port_guards.remove(&port)?.stop().await) + } + + pub async fn change_rejector( + &mut self, + port: Port, + f: impl FnOnce(&mut Packet), + ) -> anyhow::Result<()> { + let (listener, mut packet) = self + .stop_rejector(port) + .await + .ok_or_else(|| anyhow!("tried to stop rejector that is not running"))?; + + f(&mut packet); + + self.start_rejector(port, listener, packet) + } +} + +struct Rejector { + state: Arc<(Mutex, Packet)>, + handle: JoinHandle<()>, +} + +impl Debug for Rejector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PortGuard").finish() + } +} + +impl Rejector { + fn start(listener: TcpListener, packet: Packet) -> Self { + let state = Arc::new((Mutex::new(listener), packet)); + + let handle = { + let state = state.clone(); + + tokio::spawn(async move { + let (listener, packet) = state.as_ref(); + + let listener = listener.lock().await; + + loop { + if let Ok((mut socket, _)) = listener.accept().await { + let (_, mut writer) = socket.split(); + let _ = packet.send(&mut writer).await; + } + } + }) + }; + + Self { state, handle } + } + async fn stop(self) -> (TcpListener, Packet) { + self.handle.abort(); + let _ = self.handle.await; + let (listener, packet) = Arc::try_unwrap(self.state).unwrap(); + (listener.into_inner(), packet) + } +} + +impl PortHandler { + pub fn allocate_port_for_number(&mut self, config: &Config, number: Number) -> Option { + if let Some(port) = self.allocated_ports.get(&number) { + let already_connected = self + .port_state + .get(port) + .map(|state| state.status != PortStatus::Disconnected) + .unwrap_or(false); + + return if already_connected { None } else { Some(*port) }; + } + + let port = if let Some(&port) = self.free_ports.iter().next() { + self.register_update(); + self.free_ports.remove(&port); + port + } else { + self.try_recover_port(config)? + }; + + assert!(self.allocated_ports.insert(number, port).is_none()); + Some(port) + } + + fn try_recover_port(&mut self, config: &Config) -> Option { + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + + let mut recovered_port = None; + + self.errored_ports = std::mem::take(&mut self.errored_ports) + .into_iter() + .filter_map(|(mut timestamp, port)| { + if recovered_port.is_none() + && now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME + { + println!( + " trying port: {port} at -{:?}", + Duration::from_secs(now.as_secs()) + .saturating_sub(Duration::from_secs(timestamp)) + ); + + match std::net::TcpListener::bind((config.listen_addr.ip(), port)) { + Ok(_) => { + recovered_port = Some((timestamp, port)); + return None; + } + Err(_) => timestamp = now.as_secs(), + } + } else { + println!( + "skipped port: {port} at -{:?}", + Duration::from_secs(now.as_secs()) + .saturating_sub(Duration::from_secs(timestamp)) + ); + } + + Some((timestamp, port)) + }) + .collect(); + + if let Some((_, port)) = recovered_port { + self.register_update(); + println!("recovered_port: {port}"); + return Some(port); + } + + let removable_entry = self.allocated_ports.iter().find(|(_, port)| { + self.port_state + .get(port) + .map(|port_state| { + dbg!(port_state).status == PortStatus::Disconnected + && dbg!(now.saturating_sub(Duration::from_secs(port_state.last_change))) + >= PORT_OWNERSHIP_TIMEOUT + }) + .unwrap_or(true) + }); + + dbg!(&removable_entry); + + if let Some((&old_number, &port)) = removable_entry { + self.register_update(); + println!("reused port {port} which used to be allocated to {old_number} which wasn't connected in a long time"); + assert!(self.allocated_ports.remove(&old_number).is_some()); + return Some(port); + } + + None // TODO + } + + pub fn mark_port_error(&mut self, number: Number, port: Port) { + self.register_update(); + + self.errored_ports.insert(( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + port, + )); + + self.allocated_ports.remove(&number); + self.free_ports.remove(&port); + } +}