restructure code
This commit is contained in:
parent
d588b26d38
commit
4870dafbaa
@ -6,8 +6,12 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread", "net", "io-util", "sync", "time"] }
|
||||
anyhow = { version = "1.0.68", features = ["backtrace"] }
|
||||
bytemuck = { version = "1.13.0", features = ["derive"] }
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde_json = "1.0.91"
|
||||
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread", "net", "io-util", "sync", "time"] }
|
||||
|
||||
[features]
|
||||
default = ["debug_server"]
|
||||
debug_server = []
|
1
src/debug_server.rs
Normal file
1
src/debug_server.rs
Normal file
@ -0,0 +1 @@
|
||||
pub async fn debug_server() {}
|
466
src/main.rs
466
src/main.rs
@ -1,30 +1,29 @@
|
||||
// #![allow(unused)]
|
||||
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
fs::File,
|
||||
io::{BufReader, BufWriter},
|
||||
io::BufReader,
|
||||
net::{SocketAddr, ToSocketAddrs},
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use anyhow::{bail, Context};
|
||||
use debug_server::debug_server;
|
||||
use packets::{Header, Packet, RemConnect};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
net::{TcpListener, TcpStream},
|
||||
select,
|
||||
sync::Mutex,
|
||||
task::JoinHandle,
|
||||
time::{sleep, Instant},
|
||||
};
|
||||
|
||||
use crate::packets::{dyn_ip_update, PacketKind, REJECT_OOP, REJECT_TIMEOUT};
|
||||
use crate::ports::{AllowedPorts, PortHandler, PortStatus};
|
||||
|
||||
const AUTH_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const CALL_ACK_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
@ -34,14 +33,17 @@ const PORT_OWNERSHIP_TIMEOUT: Duration = Duration::from_secs(1 * 60 * 60);
|
||||
const PING_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const SEND_PING_INTERVAL: Duration = Duration::from_secs(20);
|
||||
|
||||
#[cfg(feature = "debug_server")]
|
||||
mod debug_server;
|
||||
mod packets;
|
||||
mod ports;
|
||||
|
||||
type Port = u16;
|
||||
type Number = u32;
|
||||
type UnixTimestamp = u64;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Config {
|
||||
pub struct Config {
|
||||
allowed_ports: AllowedPorts,
|
||||
#[serde(deserialize_with = "parse_socket_addr")]
|
||||
listen_addr: SocketAddr,
|
||||
@ -68,301 +70,113 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize)]
|
||||
struct PortHandler {
|
||||
#[serde(skip)]
|
||||
last_update: Option<Instant>,
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = Arc::new(Config::load("config.json")?);
|
||||
|
||||
#[serde(skip)]
|
||||
port_guards: HashMap<Port, PortGuard>,
|
||||
|
||||
allowed_ports: AllowedPorts,
|
||||
|
||||
free_ports: HashSet<Port>,
|
||||
errored_ports: BTreeSet<(UnixTimestamp, Port)>,
|
||||
allocated_ports: HashMap<Number, Port>,
|
||||
|
||||
#[serde(skip)]
|
||||
port_state: HashMap<Port, PortState>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct PortState {
|
||||
last_change: UnixTimestamp,
|
||||
status: PortStatus,
|
||||
}
|
||||
|
||||
impl PortState {
|
||||
fn new_state(&mut self, status: PortStatus) {
|
||||
self.last_change = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
self.status = status;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum PortStatus {
|
||||
Disconnected,
|
||||
Idle,
|
||||
InCall,
|
||||
}
|
||||
|
||||
impl Default for PortStatus {
|
||||
fn default() -> Self {
|
||||
Self::Disconnected
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
struct AllowedPorts(Vec<Range<u16>>);
|
||||
|
||||
impl AllowedPorts {
|
||||
fn is_allowed(&self, port: Port) -> bool {
|
||||
self.0.iter().any(|range| range.contains(&port))
|
||||
}
|
||||
}
|
||||
|
||||
impl PortHandler {
|
||||
fn register_update(&mut self) {
|
||||
self.last_update = Some(Instant::now());
|
||||
if config.allowed_ports.is_empty() {
|
||||
panic!("no allowed ports");
|
||||
}
|
||||
|
||||
fn store(&self, cache: &Path) -> anyhow::Result<()> {
|
||||
println!("storing database");
|
||||
serde_json::to_writer(BufWriter::new(File::create(cache)?), self)?;
|
||||
Ok(())
|
||||
}
|
||||
let cache_path = PathBuf::from("cache.json");
|
||||
|
||||
fn load(cache: &Path) -> std::io::Result<Self> {
|
||||
println!("loading database");
|
||||
Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?)
|
||||
}
|
||||
let mut port_handler = PortHandler::load_or_default(&cache_path);
|
||||
port_handler.update_allowed_ports(&config.allowed_ports);
|
||||
|
||||
fn load_or_default(cache: &Path) -> Self {
|
||||
Self::load(cache).unwrap_or_else(|err| {
|
||||
println!("failed to parse cache file at {cache:?} using empty cache. error: {err}");
|
||||
Self::default()
|
||||
})
|
||||
}
|
||||
let port_handler = Arc::new(Mutex::new(port_handler));
|
||||
|
||||
fn update_allowed_ports(&mut self, allowed_ports: &AllowedPorts) {
|
||||
self.register_update();
|
||||
{
|
||||
let port_handler = port_handler.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut last_store = None;
|
||||
loop {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
|
||||
self.allowed_ports = allowed_ports.clone();
|
||||
let port_handler = port_handler.lock().await;
|
||||
|
||||
self.free_ports.clear();
|
||||
self.free_ports
|
||||
.extend(self.allowed_ports.0.iter().cloned().flatten());
|
||||
if let Some(last_update) = port_handler.last_update {
|
||||
let should_store = last_store
|
||||
.map(|last_store| last_update > last_store)
|
||||
.unwrap_or(true);
|
||||
|
||||
self.free_ports.shrink_to_fit(); // we are at the maximum number of ports we'll ever reach
|
||||
|
||||
self.errored_ports
|
||||
.retain(|(_, port)| self.allowed_ports.is_allowed(*port));
|
||||
|
||||
self.allocated_ports
|
||||
.retain(|_, port| self.allowed_ports.is_allowed(*port));
|
||||
|
||||
self.free_ports.retain(|port| {
|
||||
self.allocated_ports
|
||||
.iter()
|
||||
.find(|(_, allocated_port)| *allocated_port == port)
|
||||
.is_none()
|
||||
&& self
|
||||
.errored_ports
|
||||
.iter()
|
||||
.find(|(_, errored_port)| errored_port == port)
|
||||
.is_none()
|
||||
if should_store {
|
||||
last_store = Some(last_update);
|
||||
port_handler.store(&cache_path).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn start_rejector(
|
||||
&mut self,
|
||||
port: Port,
|
||||
listener: TcpListener,
|
||||
packet: Packet,
|
||||
) -> anyhow::Result<()> {
|
||||
println!("starting rejector: for port {port} with {packet:?}");
|
||||
#[cfg(feature = "debug_server")]
|
||||
tokio::spawn(debug_server());
|
||||
|
||||
let port_guard = PortGuard::start(listener, packet);
|
||||
let listener = TcpListener::bind(config.listen_addr).await?;
|
||||
|
||||
assert!(
|
||||
self.port_guards.insert(port, port_guard).is_none(),
|
||||
"Tried to start rejector that is already running.
|
||||
This should have been impossible since it requires two listeners on the same port."
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
while let Ok((mut stream, addr)) = listener.accept().await {
|
||||
println!("connection from {addr}");
|
||||
|
||||
async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> {
|
||||
println!("stopping rejector: for port {port}");
|
||||
let port_handler = port_handler.clone();
|
||||
let config = config.clone();
|
||||
|
||||
Some(self.port_guards.remove(&port)?.stop().await)
|
||||
}
|
||||
let mut handler_metadata = HandlerMetadata::default();
|
||||
|
||||
async fn change_rejector(
|
||||
&mut self,
|
||||
port: Port,
|
||||
f: impl FnOnce(&mut Packet),
|
||||
) -> anyhow::Result<()> {
|
||||
let (listener, mut packet) = self
|
||||
.stop_rejector(port)
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("tried to stop rejector that is not running"))?;
|
||||
tokio::spawn(async move {
|
||||
let res =
|
||||
connection_handler(&config, &mut handler_metadata, &port_handler, &mut stream)
|
||||
.await;
|
||||
|
||||
f(&mut packet);
|
||||
if let Err(err) = res {
|
||||
println!("client at {addr} had an error: {err}");
|
||||
|
||||
self.start_rejector(port, listener, packet)
|
||||
}
|
||||
}
|
||||
let mut packet = Packet::default();
|
||||
|
||||
struct PortGuard {
|
||||
state: Arc<(Mutex<TcpListener>, Packet)>,
|
||||
handle: JoinHandle<()>,
|
||||
}
|
||||
packet.data.extend_from_slice(err.to_string().as_bytes());
|
||||
packet.data.truncate(0xfe);
|
||||
packet.data.push(0);
|
||||
packet.header = Header {
|
||||
kind: PacketKind::Error.raw(),
|
||||
length: packet.data.len() as u8,
|
||||
};
|
||||
|
||||
impl Debug for PortGuard {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PortGuard").finish()
|
||||
}
|
||||
}
|
||||
let (_, mut writer) = stream.split();
|
||||
let _ = packet.send(&mut writer).await;
|
||||
}
|
||||
|
||||
impl PortGuard {
|
||||
fn start(listener: TcpListener, packet: Packet) -> Self {
|
||||
let state = Arc::new((Mutex::new(listener), packet));
|
||||
if let Some(port) = handler_metadata.port {
|
||||
let mut port_handler = port_handler.lock().await;
|
||||
|
||||
let handle = {
|
||||
let state = state.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (listener, packet) = state.as_ref();
|
||||
|
||||
let listener = listener.lock().await;
|
||||
|
||||
loop {
|
||||
if let Ok((mut socket, _)) = listener.accept().await {
|
||||
let (_, mut writer) = socket.split();
|
||||
let _ = packet.send(&mut writer).await;
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Self { state, handle }
|
||||
}
|
||||
async fn stop(self) -> (TcpListener, Packet) {
|
||||
self.handle.abort();
|
||||
let _ = self.handle.await;
|
||||
let (listener, packet) = Arc::try_unwrap(self.state).unwrap();
|
||||
(listener.into_inner(), packet)
|
||||
}
|
||||
}
|
||||
|
||||
impl PortHandler {
|
||||
fn allocate_port_for_number(&mut self, config: &Config, number: Number) -> Option<Port> {
|
||||
if let Some(port) = self.allocated_ports.get(&number) {
|
||||
let already_connected = self
|
||||
.port_state
|
||||
.get(port)
|
||||
.map(|state| state.status != PortStatus::Disconnected)
|
||||
.unwrap_or(false);
|
||||
|
||||
return if already_connected { None } else { Some(*port) };
|
||||
}
|
||||
|
||||
let port = if let Some(&port) = self.free_ports.iter().next() {
|
||||
self.register_update();
|
||||
self.free_ports.remove(&port);
|
||||
port
|
||||
} else {
|
||||
self.try_recover_port(config)?
|
||||
};
|
||||
|
||||
assert!(self.allocated_ports.insert(number, port).is_none());
|
||||
Some(port)
|
||||
}
|
||||
|
||||
fn try_recover_port(&mut self, config: &Config) -> Option<Port> {
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
||||
|
||||
let mut recovered_port = None;
|
||||
|
||||
self.errored_ports = std::mem::take(&mut self.errored_ports)
|
||||
.into_iter()
|
||||
.filter_map(|(mut timestamp, port)| {
|
||||
if recovered_port.is_none()
|
||||
&& now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME
|
||||
{
|
||||
println!(
|
||||
" trying port: {port} at -{:?}",
|
||||
Duration::from_secs(now.as_secs())
|
||||
.saturating_sub(Duration::from_secs(timestamp))
|
||||
);
|
||||
|
||||
match std::net::TcpListener::bind((config.listen_addr.ip(), port)) {
|
||||
Ok(_) => {
|
||||
recovered_port = Some((timestamp, port));
|
||||
return None;
|
||||
}
|
||||
Err(_) => timestamp = now.as_secs(),
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"skipped port: {port} at -{:?}",
|
||||
Duration::from_secs(now.as_secs())
|
||||
.saturating_sub(Duration::from_secs(timestamp))
|
||||
);
|
||||
if let Some(port_state) = port_handler.port_state.get_mut(&port) {
|
||||
port_state.new_state(PortStatus::Disconnected);
|
||||
}
|
||||
|
||||
Some((timestamp, port))
|
||||
})
|
||||
.collect();
|
||||
if let Some(listener) = handler_metadata.listener.take() {
|
||||
let res = port_handler.start_rejector(
|
||||
port,
|
||||
listener,
|
||||
Packet {
|
||||
header: Header {
|
||||
kind: PacketKind::Reject.raw(),
|
||||
length: 3,
|
||||
},
|
||||
data: b"nc\0".to_vec(),
|
||||
},
|
||||
);
|
||||
|
||||
if let Some((_, port)) = recovered_port {
|
||||
self.register_update();
|
||||
println!("recovered_port: {port}");
|
||||
return Some(port);
|
||||
}
|
||||
if let Err(err) = res {
|
||||
println!(
|
||||
"failed to start rejector on port {port} after client error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let removable_entry = self.allocated_ports.iter().find(|(_, port)| {
|
||||
self.port_state
|
||||
.get(port)
|
||||
.map(|port_state| {
|
||||
dbg!(port_state).status == PortStatus::Disconnected
|
||||
&& dbg!(now.saturating_sub(Duration::from_secs(port_state.last_change)))
|
||||
>= PORT_OWNERSHIP_TIMEOUT
|
||||
})
|
||||
.unwrap_or(true)
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
let _ = stream.shutdown().await;
|
||||
});
|
||||
|
||||
dbg!(&removable_entry);
|
||||
|
||||
if let Some((&old_number, &port)) = removable_entry {
|
||||
self.register_update();
|
||||
println!("reused port {port} which used to be allocated to {old_number} which wasn't connected in a long time");
|
||||
assert!(self.allocated_ports.remove(&old_number).is_some());
|
||||
return Some(port);
|
||||
}
|
||||
|
||||
None // TODO
|
||||
}
|
||||
|
||||
fn mark_port_error(&mut self, number: Number, port: Port) {
|
||||
self.register_update();
|
||||
|
||||
self.errored_ports.insert((
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
port,
|
||||
));
|
||||
|
||||
self.allocated_ports.remove(&number);
|
||||
self.free_ports.remove(&port);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@ -649,109 +463,3 @@ async fn connection_handler(
|
||||
kind => bail!("unexpected packet: {:?}", kind),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = Arc::new(Config::load("config.json")?);
|
||||
|
||||
if config.allowed_ports.0.is_empty() {
|
||||
panic!("no allowed ports");
|
||||
}
|
||||
|
||||
let cache_path = PathBuf::from("cache.json");
|
||||
|
||||
let mut port_handler = PortHandler::load_or_default(&cache_path);
|
||||
port_handler.update_allowed_ports(&config.allowed_ports);
|
||||
|
||||
let port_handler = Arc::new(Mutex::new(port_handler));
|
||||
|
||||
{
|
||||
let port_handler = port_handler.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut last_store = None;
|
||||
loop {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let port_handler = port_handler.lock().await;
|
||||
|
||||
if let Some(last_update) = port_handler.last_update {
|
||||
let should_store = last_store
|
||||
.map(|last_store| last_update > last_store)
|
||||
.unwrap_or(true);
|
||||
|
||||
if should_store {
|
||||
last_store = Some(last_update);
|
||||
port_handler.store(&cache_path).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind(config.listen_addr).await?;
|
||||
|
||||
while let Ok((mut stream, addr)) = listener.accept().await {
|
||||
println!("connection from {addr}");
|
||||
|
||||
let port_handler = port_handler.clone();
|
||||
let config = config.clone();
|
||||
|
||||
let mut handler_metadata = HandlerMetadata::default();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let res =
|
||||
connection_handler(&config, &mut handler_metadata, &port_handler, &mut stream)
|
||||
.await;
|
||||
|
||||
if let Err(err) = res {
|
||||
println!("client at {addr} had an error: {err}");
|
||||
|
||||
let mut packet = Packet::default();
|
||||
|
||||
packet.data.extend_from_slice(err.to_string().as_bytes());
|
||||
packet.data.truncate(0xfe);
|
||||
packet.data.push(0);
|
||||
packet.header = Header {
|
||||
kind: PacketKind::Error.raw(),
|
||||
length: packet.data.len() as u8,
|
||||
};
|
||||
|
||||
let (_, mut writer) = stream.split();
|
||||
let _ = packet.send(&mut writer).await;
|
||||
}
|
||||
|
||||
if let Some(port) = handler_metadata.port {
|
||||
let mut port_handler = port_handler.lock().await;
|
||||
|
||||
if let Some(port_state) = port_handler.port_state.get_mut(&port) {
|
||||
port_state.new_state(PortStatus::Disconnected);
|
||||
}
|
||||
|
||||
if let Some(listener) = handler_metadata.listener.take() {
|
||||
let res = port_handler.start_rejector(
|
||||
port,
|
||||
listener,
|
||||
Packet {
|
||||
header: Header {
|
||||
kind: PacketKind::Reject.raw(),
|
||||
length: 3,
|
||||
},
|
||||
data: b"nc\0".to_vec(),
|
||||
},
|
||||
);
|
||||
|
||||
if let Err(err) = res {
|
||||
println!(
|
||||
"failed to start rejector on port {port} after client error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
let _ = stream.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
318
src/ports.rs
Normal file
318
src/ports.rs
Normal file
@ -0,0 +1,318 @@
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
fs::File,
|
||||
io::{BufReader, BufWriter},
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{net::TcpListener, sync::Mutex, task::JoinHandle, time::Instant};
|
||||
|
||||
use crate::{
|
||||
packets::Packet, Config, Number, Port, UnixTimestamp, PORT_OWNERSHIP_TIMEOUT, PORT_RETRY_TIME,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize)]
|
||||
pub struct PortHandler {
|
||||
#[serde(skip)]
|
||||
pub last_update: Option<Instant>,
|
||||
|
||||
#[serde(skip)]
|
||||
port_guards: HashMap<Port, Rejector>,
|
||||
|
||||
allowed_ports: AllowedPorts,
|
||||
|
||||
free_ports: HashSet<Port>,
|
||||
errored_ports: BTreeSet<(UnixTimestamp, Port)>,
|
||||
allocated_ports: HashMap<Number, Port>,
|
||||
|
||||
#[serde(skip)]
|
||||
pub port_state: HashMap<Port, PortState>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct PortState {
|
||||
last_change: UnixTimestamp,
|
||||
status: PortStatus,
|
||||
}
|
||||
|
||||
impl PortState {
|
||||
pub fn new_state(&mut self, status: PortStatus) {
|
||||
self.last_change = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
self.status = status;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum PortStatus {
|
||||
Disconnected,
|
||||
Idle,
|
||||
InCall,
|
||||
}
|
||||
|
||||
impl Default for PortStatus {
|
||||
fn default() -> Self {
|
||||
Self::Disconnected
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
pub struct AllowedPorts(Vec<Range<u16>>);
|
||||
|
||||
impl AllowedPorts {
|
||||
pub fn is_allowed(&self, port: Port) -> bool {
|
||||
self.0.iter().any(|range| range.contains(&port))
|
||||
}
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl PortHandler {
|
||||
pub fn register_update(&mut self) {
|
||||
self.last_update = Some(Instant::now());
|
||||
}
|
||||
|
||||
pub fn store(&self, cache: &Path) -> anyhow::Result<()> {
|
||||
println!("storing database");
|
||||
serde_json::to_writer(BufWriter::new(File::create(cache)?), self)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load(cache: &Path) -> std::io::Result<Self> {
|
||||
println!("loading database");
|
||||
Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?)
|
||||
}
|
||||
|
||||
pub fn load_or_default(cache: &Path) -> Self {
|
||||
Self::load(cache).unwrap_or_else(|err| {
|
||||
println!("failed to parse cache file at {cache:?} using empty cache. error: {err}");
|
||||
Self::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn update_allowed_ports(&mut self, allowed_ports: &AllowedPorts) {
|
||||
self.register_update();
|
||||
|
||||
self.allowed_ports = allowed_ports.clone();
|
||||
|
||||
self.free_ports.clear();
|
||||
self.free_ports
|
||||
.extend(self.allowed_ports.0.iter().cloned().flatten());
|
||||
|
||||
self.free_ports.shrink_to_fit(); // we are at the maximum number of ports we'll ever reach
|
||||
|
||||
self.errored_ports
|
||||
.retain(|(_, port)| self.allowed_ports.is_allowed(*port));
|
||||
|
||||
self.allocated_ports
|
||||
.retain(|_, port| self.allowed_ports.is_allowed(*port));
|
||||
|
||||
self.free_ports.retain(|port| {
|
||||
self.allocated_ports
|
||||
.iter()
|
||||
.find(|(_, allocated_port)| *allocated_port == port)
|
||||
.is_none()
|
||||
&& self
|
||||
.errored_ports
|
||||
.iter()
|
||||
.find(|(_, errored_port)| errored_port == port)
|
||||
.is_none()
|
||||
});
|
||||
}
|
||||
|
||||
pub fn start_rejector(
|
||||
&mut self,
|
||||
port: Port,
|
||||
listener: TcpListener,
|
||||
packet: Packet,
|
||||
) -> anyhow::Result<()> {
|
||||
println!("starting rejector: for port {port} with {packet:?}");
|
||||
|
||||
let port_guard = Rejector::start(listener, packet);
|
||||
|
||||
assert!(
|
||||
self.port_guards.insert(port, port_guard).is_none(),
|
||||
"Tried to start rejector that is already running.
|
||||
This should have been impossible since it requires two listeners on the same port."
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> {
|
||||
println!("stopping rejector: for port {port}");
|
||||
|
||||
Some(self.port_guards.remove(&port)?.stop().await)
|
||||
}
|
||||
|
||||
pub async fn change_rejector(
|
||||
&mut self,
|
||||
port: Port,
|
||||
f: impl FnOnce(&mut Packet),
|
||||
) -> anyhow::Result<()> {
|
||||
let (listener, mut packet) = self
|
||||
.stop_rejector(port)
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("tried to stop rejector that is not running"))?;
|
||||
|
||||
f(&mut packet);
|
||||
|
||||
self.start_rejector(port, listener, packet)
|
||||
}
|
||||
}
|
||||
|
||||
struct Rejector {
|
||||
state: Arc<(Mutex<TcpListener>, Packet)>,
|
||||
handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Debug for Rejector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PortGuard").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Rejector {
|
||||
fn start(listener: TcpListener, packet: Packet) -> Self {
|
||||
let state = Arc::new((Mutex::new(listener), packet));
|
||||
|
||||
let handle = {
|
||||
let state = state.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (listener, packet) = state.as_ref();
|
||||
|
||||
let listener = listener.lock().await;
|
||||
|
||||
loop {
|
||||
if let Ok((mut socket, _)) = listener.accept().await {
|
||||
let (_, mut writer) = socket.split();
|
||||
let _ = packet.send(&mut writer).await;
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Self { state, handle }
|
||||
}
|
||||
async fn stop(self) -> (TcpListener, Packet) {
|
||||
self.handle.abort();
|
||||
let _ = self.handle.await;
|
||||
let (listener, packet) = Arc::try_unwrap(self.state).unwrap();
|
||||
(listener.into_inner(), packet)
|
||||
}
|
||||
}
|
||||
|
||||
impl PortHandler {
|
||||
pub fn allocate_port_for_number(&mut self, config: &Config, number: Number) -> Option<Port> {
|
||||
if let Some(port) = self.allocated_ports.get(&number) {
|
||||
let already_connected = self
|
||||
.port_state
|
||||
.get(port)
|
||||
.map(|state| state.status != PortStatus::Disconnected)
|
||||
.unwrap_or(false);
|
||||
|
||||
return if already_connected { None } else { Some(*port) };
|
||||
}
|
||||
|
||||
let port = if let Some(&port) = self.free_ports.iter().next() {
|
||||
self.register_update();
|
||||
self.free_ports.remove(&port);
|
||||
port
|
||||
} else {
|
||||
self.try_recover_port(config)?
|
||||
};
|
||||
|
||||
assert!(self.allocated_ports.insert(number, port).is_none());
|
||||
Some(port)
|
||||
}
|
||||
|
||||
fn try_recover_port(&mut self, config: &Config) -> Option<Port> {
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
||||
|
||||
let mut recovered_port = None;
|
||||
|
||||
self.errored_ports = std::mem::take(&mut self.errored_ports)
|
||||
.into_iter()
|
||||
.filter_map(|(mut timestamp, port)| {
|
||||
if recovered_port.is_none()
|
||||
&& now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME
|
||||
{
|
||||
println!(
|
||||
" trying port: {port} at -{:?}",
|
||||
Duration::from_secs(now.as_secs())
|
||||
.saturating_sub(Duration::from_secs(timestamp))
|
||||
);
|
||||
|
||||
match std::net::TcpListener::bind((config.listen_addr.ip(), port)) {
|
||||
Ok(_) => {
|
||||
recovered_port = Some((timestamp, port));
|
||||
return None;
|
||||
}
|
||||
Err(_) => timestamp = now.as_secs(),
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"skipped port: {port} at -{:?}",
|
||||
Duration::from_secs(now.as_secs())
|
||||
.saturating_sub(Duration::from_secs(timestamp))
|
||||
);
|
||||
}
|
||||
|
||||
Some((timestamp, port))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some((_, port)) = recovered_port {
|
||||
self.register_update();
|
||||
println!("recovered_port: {port}");
|
||||
return Some(port);
|
||||
}
|
||||
|
||||
let removable_entry = self.allocated_ports.iter().find(|(_, port)| {
|
||||
self.port_state
|
||||
.get(port)
|
||||
.map(|port_state| {
|
||||
dbg!(port_state).status == PortStatus::Disconnected
|
||||
&& dbg!(now.saturating_sub(Duration::from_secs(port_state.last_change)))
|
||||
>= PORT_OWNERSHIP_TIMEOUT
|
||||
})
|
||||
.unwrap_or(true)
|
||||
});
|
||||
|
||||
dbg!(&removable_entry);
|
||||
|
||||
if let Some((&old_number, &port)) = removable_entry {
|
||||
self.register_update();
|
||||
println!("reused port {port} which used to be allocated to {old_number} which wasn't connected in a long time");
|
||||
assert!(self.allocated_ports.remove(&old_number).is_some());
|
||||
return Some(port);
|
||||
}
|
||||
|
||||
None // TODO
|
||||
}
|
||||
|
||||
pub fn mark_port_error(&mut self, number: Number, port: Port) {
|
||||
self.register_update();
|
||||
|
||||
self.errored_ports.insert((
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
port,
|
||||
));
|
||||
|
||||
self.allocated_ports.remove(&number);
|
||||
self.free_ports.remove(&port);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user