diff --git a/.gitignore b/.gitignore index da1802e..9f9be56 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ /target -db.json +cache.json diff --git a/config.json b/config.json new file mode 100644 index 0000000..8ffc10d --- /dev/null +++ b/config.json @@ -0,0 +1,8 @@ +{ + "allowed_ports": [ + [ + 3000, + 3005 + ] + ] +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index dc71eab..e387466 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,38 +1,38 @@ -#![feature(generic_const_exprs)] -#![allow(unused)] +// #![allow(unused)] use std::{ - collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + collections::{BTreeSet, HashMap, HashSet}, fmt::Debug, fs::File, - future::Future, io::{BufReader, BufWriter}, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, SocketAddr}, ops::Range, path::{Path, PathBuf}, - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; -use anyhow::bail; -use packets::{reject_static, Header, Packet, RemConnect}; +use anyhow::{anyhow, bail}; +use packets::{Header, Packet, RemConnect}; use serde::{Deserialize, Serialize}; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpListener, TcpSocket, TcpStream}, + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, select, + sync::Mutex, task::JoinHandle, - time::Instant, + time::{sleep, Instant}, }; -use crate::packets::dyn_ip_update; +use crate::packets::{dyn_ip_update, PacketKind, REJECT_OOP, REJECT_TIMEOUT}; const AUTH_TIMEOUT: Duration = Duration::from_secs(30); const CALL_ACK_TIMEOUT: Duration = Duration::from_secs(30); -const PING_INTERVAL: Duration = Duration::from_secs(15); -const TIMEOUT_DELAY: Duration = Duration::from_secs(35); -const PORT_TIMEOUT: Duration = Duration::from_secs(60); -const PORT_RETRY_TIME: Duration = Duration::from_secs(60); // 10 * +const CALL_TIMEOUT: Duration = Duration::from_secs(24 * 60 * 60); +const PORT_RETRY_TIME: Duration = Duration::from_secs(15 * 60); +const PORT_OWNERSHIP_TIMEOUT: Duration = Duration::from_secs(1 * 60 * 60); +const PING_TIMEOUT: Duration = Duration::from_secs(30); +const SEND_PING_INTERVAL: Duration = Duration::from_secs(30); const BIND_IP: &str = "0.0.0.0"; @@ -48,14 +48,14 @@ struct Config { } impl Config { - fn load(db: &Path) -> std::io::Result { + fn load(cache: &Path) -> std::io::Result { println!("loading config"); - Ok(serde_json::from_reader(BufReader::new(File::open(db)?))?) + Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?) } - fn load_or_default(db: &Path) -> std::io::Result { - match Self::load(db) { - Ok(db) => Ok(db), + fn load_or_default(cache: &Path) -> std::io::Result { + match Self::load(cache) { + Ok(cache) => Ok(cache), Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(Self::default()), Err(err) => Err(err), } @@ -75,11 +75,40 @@ struct PortHandler { free_ports: HashSet, errored_ports: BTreeSet<(UnixTimestamp, Port)>, allocated_ports: HashMap, - port_status: HashMap, + + #[serde(skip)] + port_state: HashMap, } -#[derive(Debug, Serialize, Deserialize)] -struct PortStatus {} +#[derive(Default, Debug)] +struct PortState { + last_change: UnixTimestamp, + status: PortStatus, +} + +impl PortState { + fn new_state(&mut self, status: PortStatus) { + self.last_change = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + self.status = status; + } +} + +#[derive(Debug, PartialEq, Eq)] +enum PortStatus { + Disconnected, + Idle, + InCall, +} + +impl Default for PortStatus { + fn default() -> Self { + Self::Disconnected + } +} #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] struct AllowedPorts(Vec>); @@ -95,23 +124,19 @@ impl PortHandler { self.last_update = Some(Instant::now()); } - fn store(&self, db: &Path) -> anyhow::Result<()> { + fn store(&self, cache: &Path) -> anyhow::Result<()> { println!("storing database"); - serde_json::to_writer(BufWriter::new(File::create(db)?), self)?; + serde_json::to_writer(BufWriter::new(File::create(cache)?), self)?; Ok(()) } - fn load(db: &Path) -> std::io::Result { + fn load(cache: &Path) -> std::io::Result { println!("loading database"); - Ok(serde_json::from_reader(BufReader::new(File::open(db)?))?) + Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?) } - fn load_or_default(db: &Path) -> std::io::Result { - match Self::load(db) { - Ok(db) => Ok(db), - Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(Self::default()), - Err(err) => Err(err), - } + fn load_or_default(cache: &Path) -> Self { + Self::load(cache).unwrap_or(Self::default()) } fn update_allowed_ports(&mut self, allowed_ports: &AllowedPorts) { @@ -144,37 +169,48 @@ impl PortHandler { }); } - fn start_port_guard<'fut, Fut, Func>(&mut self, port: Port, listener: TcpListener, f: Func) - where - Fut: Future + Send + 'fut, - Func: FnOnce(&'_ mut TcpListener) -> Fut + Send + 'static, - { - assert!(self - .port_guards - .insert(port, PortGuard::start(listener, f)) - .is_none()); + fn start_rejector( + &mut self, + port: Port, + listener: TcpListener, + packet: Packet, + ) -> anyhow::Result<()> { + println!("starting rejector: for port {port} with packet {packet:?}"); + + let port_guard = PortGuard::start(listener, packet); + + assert!( + self.port_guards.insert(port, port_guard).is_none(), + "Tried to start rejector that is already running. + This should have been impossible since it requires two listeners on the same port." + ); + Ok(()) } - fn start_rejector(&mut self, port: Port, listener: TcpListener, packet: Packet) { - assert!(self - .port_guards - .insert( - port, - PortGuard::start(listener, move |listener: &mut TcpListener| async move { - loop { - if let Ok((mut socket, _)) = listener.accept().await { - let (_, mut writer) = socket.split(); - let _ = packet.send(&mut writer).await; - } - } - }) - ) - .is_none()); + async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> { + println!("stopping rejector: for port {port}"); + + Some(self.port_guards.remove(&port)?.stop().await) + } + + async fn change_rejector( + &mut self, + port: Port, + f: impl FnOnce(&mut Packet), + ) -> anyhow::Result<()> { + let (listener, mut packet) = self + .stop_rejector(port) + .await + .ok_or_else(|| anyhow!("tried to stop rejector that is not running"))?; + + f(&mut packet); + + self.start_rejector(port, listener, packet) } } struct PortGuard { - listener: Arc>, + state: Arc<(Mutex, Packet)>, handle: JoinHandle<()>, } @@ -185,38 +221,46 @@ impl Debug for PortGuard { } impl PortGuard { - fn start<'fut, Fut>( - listener: TcpListener, - f: impl FnOnce(&mut TcpListener) -> Fut + Send + 'static, - ) -> Self - where - Fut: Future + Send + 'fut, - { - let mut listener = Arc::new(tokio::sync::Mutex::new(listener)); + fn start(listener: TcpListener, packet: Packet) -> Self { + let state = Arc::new((Mutex::new(listener), packet)); let handle = { - let listener = listener.clone(); + let state = state.clone(); tokio::spawn(async move { - let mut lock = listener.lock().await; - f(&mut *lock).await; + let (listener, packet) = state.as_ref(); + + let listener = listener.lock().await; + + loop { + if let Ok((mut socket, _)) = listener.accept().await { + let (_, mut writer) = socket.split(); + let _ = packet.send(&mut writer).await; + } + } }) }; - Self { listener, handle } + Self { state, handle } } - - async fn stop(mut self) -> TcpListener { + async fn stop(self) -> (TcpListener, Packet) { self.handle.abort(); let _ = self.handle.await; - Arc::try_unwrap(self.listener).unwrap().into_inner() + let (listener, packet) = Arc::try_unwrap(self.state).unwrap(); + (listener.into_inner(), packet) } } impl PortHandler { fn allocate_port_for_number(&mut self, number: Number) -> Option { if let Some(port) = self.allocated_ports.get(&number) { - return Some(*port); + let already_connected = self + .port_state + .get(port) + .map(|state| state.status != PortStatus::Disconnected) + .unwrap_or(false); + + return if already_connected { None } else { Some(*port) }; } let port = if let Some(&port) = self.free_ports.iter().next() { @@ -238,7 +282,7 @@ impl PortHandler { self.errored_ports = std::mem::take(&mut self.errored_ports) .into_iter() - .filter_map(|(mut timestamp, mut port)| { + .filter_map(|(mut timestamp, port)| { if recovered_port.is_none() && now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME { @@ -268,10 +312,31 @@ impl PortHandler { .collect(); if let Some((_, port)) = recovered_port { + self.register_update(); println!("recovered_port: {port}"); return Some(port); } + let removable_entry = self.allocated_ports.iter().find(|(_, port)| { + self.port_state + .get(port) + .map(|port_state| { + dbg!(port_state).status == PortStatus::Disconnected + && dbg!(now.saturating_sub(Duration::from_secs(port_state.last_change))) + >= PORT_OWNERSHIP_TIMEOUT + }) + .unwrap_or(true) + }); + + dbg!(&removable_entry); + + if let Some((&old_number, &port)) = removable_entry { + self.register_update(); + println!("reused port {port} which used to be allocated to {old_number} which wasn't connected in a long time"); + assert!(self.allocated_ports.remove(&old_number).is_some()); + return Some(port); + } + None // TODO } @@ -289,92 +354,241 @@ impl PortHandler { self.allocated_ports.remove(&number); self.free_ports.remove(&port); } +} - fn open_port(&mut self, port: Port) -> Option { - todo!() - } - - fn close_port_for(&mut self, number: Number, listener: TcpListener) -> anyhow::Result<()> { - todo!() - } +#[derive(Debug, Default)] +struct HandlerMetadata { + number: Option, + port: Option, } async fn connection_handler( - port_handler: Arc>, + handler_metadata: &mut HandlerMetadata, + port_handler: &Mutex, stream: &mut TcpStream, ) -> anyhow::Result<()> { let (mut reader, mut writer) = stream.split(); - let mut packet = Packet::recv(&mut reader).await?; + let mut packet = Packet::default(); + + select! { + res = packet.recv_into_cancelation_safe(&mut reader) => res?, + _ = sleep(AUTH_TIMEOUT) => { + writer.write_all(REJECT_TIMEOUT).await?; + return Ok(()); + } + } let RemConnect { number, pin } = packet.as_rem_connect()?; + handler_metadata.number = Some(number); + + let mut authenticated = false; let (port, listener) = loop { - let port = port_handler - .lock() - .unwrap() - .allocate_port_for_number(number); + let mut updated_server = false; + + let port = port_handler.lock().await.allocate_port_for_number(number); println!("allocated port: {:?}", port); let Some(port) = port else { - writer.write_all(&reject_static(b"oop")).await?; + writer.write_all(REJECT_OOP).await?; return Ok(()); }; - let ip = dyn_ip_update(number, pin, port).await?; + if !authenticated { + let _ip = dyn_ip_update(number, pin, port).await?; + authenticated = true; + updated_server = true; + } - let listener = TcpListener::bind((BIND_IP, port)).await; + let mut port_handler = port_handler.lock().await; - let listener = match listener { - Ok(listener) => break (port, listener), - Err(err) => { - port_handler.lock().unwrap().mark_port_error(number, port); - // tokio::time::sleep(Duration::from_millis(300)).await; + let listener = if let Some((listener, _package)) = port_handler.stop_rejector(port).await { + Ok(listener) + } else { + TcpListener::bind((BIND_IP, port)).await + }; + + match listener { + Ok(listener) => { + if !updated_server { + let _ip = dyn_ip_update(number, pin, port).await?; + } + + port_handler + .port_state + .entry(port) + .or_default() + .new_state(PortStatus::Idle); + + handler_metadata.port = Some(port); + + break (port, listener); + } + Err(_err) => { + port_handler.mark_port_error(number, port); continue; } }; }; #[derive(Debug)] - enum Foo { - Caller { stream: TcpStream, addr: SocketAddr }, - Packet { packet: Packet }, + enum Result { + Caller { + packet: Packet, + stream: TcpStream, + addr: SocketAddr, + }, + Packet { + packet: Packet, + }, } - let result = select! { - kind = Packet::peek_packet_kind(&mut reader) => { - packet.recv_into(&mut reader).await?; - Foo::Packet { packet } - }, - caller = listener.accept() => { - let (stream, addr) = caller?; - Foo::Caller { stream, addr } - }, + let mut last_ping_sent_at = Instant::now(); + let mut last_ping_received_at = Instant::now(); + + let result = loop { + let now = Instant::now(); + + select! { + caller = listener.accept() => { + let (stream, addr) = caller?; + break Result::Caller { packet, stream, addr } + }, + _ = Packet::peek_packet_kind(&mut reader) => { + packet.recv_into(&mut reader).await?; + + if packet.kind() == PacketKind::Ping { + last_ping_received_at = now; + } else { + break Result::Packet { packet } + } + }, + _ = sleep(now.saturating_duration_since(last_ping_sent_at).saturating_sub(SEND_PING_INTERVAL)) => { + writer.write_all(bytemuck::bytes_of(& Header { kind: PacketKind::Ping.raw(), length: 0 })).await?; + last_ping_sent_at = now; + } + _ = sleep(now.saturating_duration_since(last_ping_received_at).saturating_sub(PING_TIMEOUT)) => { + writer.write_all(REJECT_TIMEOUT).await?; + return Ok(()); + } + } }; - dbg!(&result); + let (mut client, mut packet) = match result { + Result::Packet { packet } => { + if matches!( + packet.kind(), + packets::PacketKind::End | packets::PacketKind::Reject + ) { + println!("got disconnect packet: {packet:?}"); - match result { - Foo::Caller { stream, addr } => todo!(), - Foo::Packet { mut packet } => { - match packet.kind() { - packets::PacketKind::End => { - packet.header = Header { kind: 3, length: 0 }; - packet.data.clear(); - } - packets::PacketKind::Reject => {} - - kind => bail!("unexpected packet: {kind:?}"), + port_handler + .lock() + .await + .start_rejector(port, listener, packet)?; + return Ok(()); + } else { + bail!("unexpected packet: {:?}", packet.kind()) } - port_handler - .lock() - .unwrap() - .start_rejector(port, listener, packet); + } + Result::Caller { + mut packet, + stream, + addr, + } => { + println!("got caller from: {addr}"); + + packet.data.clear(); + match addr.ip() { + IpAddr::V4(addr) => packet.data.extend_from_slice(&addr.octets()), + IpAddr::V6(addr) => packet.data.extend_from_slice(&addr.octets()), + } + packet.header = Header { + kind: PacketKind::RemCall.raw(), + length: packet.data.len() as u8, + }; + + packet.send(&mut writer).await?; + + (stream, packet) + } + }; + + select! { + res = packet.recv_into_cancelation_safe(&mut reader) => res?, + _ = sleep(CALL_ACK_TIMEOUT) => { + writer.write_all(REJECT_TIMEOUT).await?; + return Ok(()); } } - Ok(()) + match packet.kind() { + PacketKind::End | PacketKind::Reject => { + port_handler + .lock() + .await + .start_rejector(port, listener, packet)?; + + return Ok(()); + } + + PacketKind::RemAck => { + packet.header = Header { + kind: PacketKind::Reject.raw(), + length: 4, + }; + packet.data.clear(); + packet.data.extend_from_slice(b"occ"); + packet.data.push(0); + + { + let mut port_handler = port_handler.lock().await; + + port_handler.register_update(); + port_handler + .port_state + .entry(port) + .or_default() + .new_state(PortStatus::InCall); + + port_handler.start_rejector(port, listener, packet)?; + } + + select! { + _ = tokio::io::copy_bidirectional(stream, &mut client) => {} + _ = sleep(CALL_TIMEOUT) => {} + } + + { + let mut port_handler = port_handler.lock().await; + + port_handler.register_update(); + port_handler + .port_state + .entry(port) + .or_default() + .new_state(PortStatus::Disconnected); + + port_handler + .change_rejector(port, |packet| { + packet.data.clear(); + packet.data.extend_from_slice(b"nc"); + packet.data.push(0); + packet.header = Header { + kind: PacketKind::Reject.raw(), + length: packet.data.len() as u8, + }; + }) + .await?; + } + + return Ok(()); + } + + kind => bail!("unexpected packet: {:?}", kind), + } } #[tokio::main] @@ -385,9 +599,9 @@ async fn main() -> anyhow::Result<()> { panic!("no allowed ports"); } - let db_path = PathBuf::from("db.json"); + let cache_path = PathBuf::from("cache.json"); - let mut port_handler = PortHandler::load_or_default(&db_path)?; + let mut port_handler = PortHandler::load_or_default(&cache_path); port_handler.update_allowed_ports(&config.allowed_ports); let port_handler = Arc::new(Mutex::new(port_handler)); @@ -397,9 +611,9 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let mut last_store = None; loop { - tokio::time::sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; - let port_handler = port_handler.lock().unwrap(); + let port_handler = port_handler.lock().await; if let Some(last_update) = port_handler.last_update { let should_store = last_store @@ -408,22 +622,26 @@ async fn main() -> anyhow::Result<()> { if should_store { last_store = Some(last_update); - port_handler.store(&db_path).unwrap(); + port_handler.store(&cache_path).unwrap(); } } } }); } - let listener = TcpListener::bind(("127.0.0.1", 11812)).await?; + let listener = TcpListener::bind(("0.0.0.0", 11820)).await?; while let Ok((mut stream, addr)) = listener.accept().await { println!("connection from {addr}"); let port_handler = port_handler.clone(); + let mut handler_metadata = HandlerMetadata::default(); + tokio::spawn(async move { - if let Err(err) = connection_handler(port_handler, &mut stream).await { + let res = connection_handler(&mut handler_metadata, &port_handler, &mut stream).await; + + if let Err(err) = res { println!("client at {addr} had an error: {err}"); let mut packet = Packet::default(); @@ -432,13 +650,26 @@ async fn main() -> anyhow::Result<()> { packet.data.truncate(0xfe); packet.data.push(0); packet.header = Header { - kind: 0xff, + kind: PacketKind::Error.raw(), length: packet.data.len() as u8, }; let (_, mut writer) = stream.split(); let _ = packet.send(&mut writer).await; } + + // if let Some(number) = handler_metadata.number { + // + // } + + if let Some(port) = handler_metadata.port { + if let Some(port_state) = port_handler.lock().await.port_state.get_mut(&port) { + port_state.new_state(PortStatus::Disconnected); + } + } + + sleep(Duration::from_secs(3)).await; + let _ = stream.shutdown().await; }); } diff --git a/src/packets.rs b/src/packets.rs index 33dbc40..4ec2619 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -1,5 +1,3 @@ -use std::{ffi::CString, mem::discriminant}; - use anyhow::bail; use bytemuck::{Pod, Zeroable}; use tokio::{ @@ -7,25 +5,13 @@ use tokio::{ net::tcp::{ReadHalf, WriteHalf}, }; -pub const fn reject_static(message: &[u8; N]) -> [u8; N + 2] { - let mut pkg = [0u8; N + 2]; - pkg[0] = 4; - pkg[1] = message.len() as u8; - let mut i = 0; - while i < message.len() { - pkg[i + 2] = message[i]; - i += 1; - } - pkg -} - -pub const REJECT_OCC: &[u8; 6] = b"\x04\x04occ\x00"; -pub const REJECT_NC: &[u8; 5] = b"\x04\x03nc\x00"; +pub const REJECT_OOP: &[u8; 6] = b"\x04\x04oop\x00"; +pub const REJECT_TIMEOUT: &[u8; 10] = b"\x04\x08timeout\x00"; #[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PacketKind { - Unknown(u8), + Ping = 0x00, DynIpUpdate = 0x01, DynIpUpdateResponse = 0x02, End = 0x03, @@ -34,6 +20,7 @@ pub enum PacketKind { RemConfirm = 0x82, RemCall = 0x83, RemAck = 0x84, + Unknown(u8), Error = 0xff, } @@ -42,6 +29,7 @@ impl PacketKind { use PacketKind::*; match raw { + 0x00 => Ping, 0x01 => DynIpUpdate, 0x02 => DynIpUpdateResponse, 0x03 => End, @@ -55,11 +43,11 @@ impl PacketKind { } } - fn kind(&self) -> u8 { + pub fn raw(&self) -> u8 { use PacketKind::*; match self { - Unknown(value) => *value, + Ping => 0, DynIpUpdate => 0x01, DynIpUpdateResponse => 0x02, End => 0x03, @@ -69,6 +57,8 @@ impl PacketKind { RemCall => 0x83, RemAck => 0x84, Error => 0xff, + + Unknown(value) => *value, } } } @@ -111,10 +101,18 @@ impl Packet { } } - pub async fn recv(stream: &mut ReadHalf<'_>) -> std::io::Result { - let mut packet = Packet::default(); - packet.recv_into(stream).await?; - Ok(packet) + pub async fn recv_into_cancelation_safe( + &mut self, + stream: &mut ReadHalf<'_>, + ) -> std::io::Result<()> { + // Makes sure all data is available before reading + let header_bytes = bytemuck::bytes_of_mut(&mut self.header); + stream.peek(header_bytes).await?; + self.data.resize(self.header.length as usize + 2, 0); + stream.peek(&mut self.data).await?; + + // All data is available. Read the data + self.recv_into(stream).await } pub async fn recv_into(&mut self, stream: &mut ReadHalf<'_>) -> std::io::Result<()> { @@ -161,7 +159,7 @@ impl Packet { pub async fn dyn_ip_update(number: u32, pin: u16, port: u16) -> anyhow::Result { let mut packet = Packet::default(); packet.header = Header { - kind: PacketKind::DynIpUpdate.kind(), + kind: PacketKind::DynIpUpdate.raw(), length: 8, }; @@ -171,7 +169,8 @@ pub async fn dyn_ip_update(number: u32, pin: u16, port: u16) -> anyhow::Result