added most features

This commit is contained in:
soruh 2023-01-25 22:04:53 +01:00
parent 3163644c62
commit 6a3324563c
4 changed files with 398 additions and 160 deletions

2
.gitignore vendored
View File

@ -1,2 +1,2 @@
/target
db.json
cache.json

8
config.json Normal file
View File

@ -0,0 +1,8 @@
{
"allowed_ports": [
[
3000,
3005
]
]
}

View File

@ -1,38 +1,38 @@
#![feature(generic_const_exprs)]
#![allow(unused)]
// #![allow(unused)]
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
collections::{BTreeSet, HashMap, HashSet},
fmt::Debug,
fs::File,
future::Future,
io::{BufReader, BufWriter},
net::{IpAddr, Ipv4Addr, SocketAddr},
net::{IpAddr, SocketAddr},
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Mutex},
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use anyhow::bail;
use packets::{reject_static, Header, Packet, RemConnect};
use anyhow::{anyhow, bail};
use packets::{Header, Packet, RemConnect};
use serde::{Deserialize, Serialize};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpSocket, TcpStream},
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
select,
sync::Mutex,
task::JoinHandle,
time::Instant,
time::{sleep, Instant},
};
use crate::packets::dyn_ip_update;
use crate::packets::{dyn_ip_update, PacketKind, REJECT_OOP, REJECT_TIMEOUT};
const AUTH_TIMEOUT: Duration = Duration::from_secs(30);
const CALL_ACK_TIMEOUT: Duration = Duration::from_secs(30);
const PING_INTERVAL: Duration = Duration::from_secs(15);
const TIMEOUT_DELAY: Duration = Duration::from_secs(35);
const PORT_TIMEOUT: Duration = Duration::from_secs(60);
const PORT_RETRY_TIME: Duration = Duration::from_secs(60); // 10 *
const CALL_TIMEOUT: Duration = Duration::from_secs(24 * 60 * 60);
const PORT_RETRY_TIME: Duration = Duration::from_secs(15 * 60);
const PORT_OWNERSHIP_TIMEOUT: Duration = Duration::from_secs(1 * 60 * 60);
const PING_TIMEOUT: Duration = Duration::from_secs(30);
const SEND_PING_INTERVAL: Duration = Duration::from_secs(30);
const BIND_IP: &str = "0.0.0.0";
@ -48,14 +48,14 @@ struct Config {
}
impl Config {
fn load(db: &Path) -> std::io::Result<Self> {
fn load(cache: &Path) -> std::io::Result<Self> {
println!("loading config");
Ok(serde_json::from_reader(BufReader::new(File::open(db)?))?)
Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?)
}
fn load_or_default(db: &Path) -> std::io::Result<Self> {
match Self::load(db) {
Ok(db) => Ok(db),
fn load_or_default(cache: &Path) -> std::io::Result<Self> {
match Self::load(cache) {
Ok(cache) => Ok(cache),
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(Self::default()),
Err(err) => Err(err),
}
@ -75,11 +75,40 @@ struct PortHandler {
free_ports: HashSet<Port>,
errored_ports: BTreeSet<(UnixTimestamp, Port)>,
allocated_ports: HashMap<Number, Port>,
port_status: HashMap<Port, PortStatus>,
#[serde(skip)]
port_state: HashMap<Port, PortState>,
}
#[derive(Debug, Serialize, Deserialize)]
struct PortStatus {}
#[derive(Default, Debug)]
struct PortState {
last_change: UnixTimestamp,
status: PortStatus,
}
impl PortState {
fn new_state(&mut self, status: PortStatus) {
self.last_change = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
self.status = status;
}
}
#[derive(Debug, PartialEq, Eq)]
enum PortStatus {
Disconnected,
Idle,
InCall,
}
impl Default for PortStatus {
fn default() -> Self {
Self::Disconnected
}
}
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
struct AllowedPorts(Vec<Range<u16>>);
@ -95,23 +124,19 @@ impl PortHandler {
self.last_update = Some(Instant::now());
}
fn store(&self, db: &Path) -> anyhow::Result<()> {
fn store(&self, cache: &Path) -> anyhow::Result<()> {
println!("storing database");
serde_json::to_writer(BufWriter::new(File::create(db)?), self)?;
serde_json::to_writer(BufWriter::new(File::create(cache)?), self)?;
Ok(())
}
fn load(db: &Path) -> std::io::Result<Self> {
fn load(cache: &Path) -> std::io::Result<Self> {
println!("loading database");
Ok(serde_json::from_reader(BufReader::new(File::open(db)?))?)
Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?)
}
fn load_or_default(db: &Path) -> std::io::Result<Self> {
match Self::load(db) {
Ok(db) => Ok(db),
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(Self::default()),
Err(err) => Err(err),
}
fn load_or_default(cache: &Path) -> Self {
Self::load(cache).unwrap_or(Self::default())
}
fn update_allowed_ports(&mut self, allowed_ports: &AllowedPorts) {
@ -144,37 +169,48 @@ impl PortHandler {
});
}
fn start_port_guard<'fut, Fut, Func>(&mut self, port: Port, listener: TcpListener, f: Func)
where
Fut: Future<Output = ()> + Send + 'fut,
Func: FnOnce(&'_ mut TcpListener) -> Fut + Send + 'static,
{
assert!(self
.port_guards
.insert(port, PortGuard::start(listener, f))
.is_none());
fn start_rejector(
&mut self,
port: Port,
listener: TcpListener,
packet: Packet,
) -> anyhow::Result<()> {
println!("starting rejector: for port {port} with packet {packet:?}");
let port_guard = PortGuard::start(listener, packet);
assert!(
self.port_guards.insert(port, port_guard).is_none(),
"Tried to start rejector that is already running.
This should have been impossible since it requires two listeners on the same port."
);
Ok(())
}
fn start_rejector(&mut self, port: Port, listener: TcpListener, packet: Packet) {
assert!(self
.port_guards
.insert(
port,
PortGuard::start(listener, move |listener: &mut TcpListener| async move {
loop {
if let Ok((mut socket, _)) = listener.accept().await {
let (_, mut writer) = socket.split();
let _ = packet.send(&mut writer).await;
}
}
})
)
.is_none());
async fn stop_rejector(&mut self, port: Port) -> Option<(TcpListener, Packet)> {
println!("stopping rejector: for port {port}");
Some(self.port_guards.remove(&port)?.stop().await)
}
async fn change_rejector(
&mut self,
port: Port,
f: impl FnOnce(&mut Packet),
) -> anyhow::Result<()> {
let (listener, mut packet) = self
.stop_rejector(port)
.await
.ok_or_else(|| anyhow!("tried to stop rejector that is not running"))?;
f(&mut packet);
self.start_rejector(port, listener, packet)
}
}
struct PortGuard {
listener: Arc<tokio::sync::Mutex<TcpListener>>,
state: Arc<(Mutex<TcpListener>, Packet)>,
handle: JoinHandle<()>,
}
@ -185,38 +221,46 @@ impl Debug for PortGuard {
}
impl PortGuard {
fn start<'fut, Fut>(
listener: TcpListener,
f: impl FnOnce(&mut TcpListener) -> Fut + Send + 'static,
) -> Self
where
Fut: Future<Output = ()> + Send + 'fut,
{
let mut listener = Arc::new(tokio::sync::Mutex::new(listener));
fn start(listener: TcpListener, packet: Packet) -> Self {
let state = Arc::new((Mutex::new(listener), packet));
let handle = {
let listener = listener.clone();
let state = state.clone();
tokio::spawn(async move {
let mut lock = listener.lock().await;
f(&mut *lock).await;
let (listener, packet) = state.as_ref();
let listener = listener.lock().await;
loop {
if let Ok((mut socket, _)) = listener.accept().await {
let (_, mut writer) = socket.split();
let _ = packet.send(&mut writer).await;
}
}
})
};
Self { listener, handle }
Self { state, handle }
}
async fn stop(mut self) -> TcpListener {
async fn stop(self) -> (TcpListener, Packet) {
self.handle.abort();
let _ = self.handle.await;
Arc::try_unwrap(self.listener).unwrap().into_inner()
let (listener, packet) = Arc::try_unwrap(self.state).unwrap();
(listener.into_inner(), packet)
}
}
impl PortHandler {
fn allocate_port_for_number(&mut self, number: Number) -> Option<Port> {
if let Some(port) = self.allocated_ports.get(&number) {
return Some(*port);
let already_connected = self
.port_state
.get(port)
.map(|state| state.status != PortStatus::Disconnected)
.unwrap_or(false);
return if already_connected { None } else { Some(*port) };
}
let port = if let Some(&port) = self.free_ports.iter().next() {
@ -238,7 +282,7 @@ impl PortHandler {
self.errored_ports = std::mem::take(&mut self.errored_ports)
.into_iter()
.filter_map(|(mut timestamp, mut port)| {
.filter_map(|(mut timestamp, port)| {
if recovered_port.is_none()
&& now.saturating_sub(Duration::from_secs(timestamp)) >= PORT_RETRY_TIME
{
@ -268,10 +312,31 @@ impl PortHandler {
.collect();
if let Some((_, port)) = recovered_port {
self.register_update();
println!("recovered_port: {port}");
return Some(port);
}
let removable_entry = self.allocated_ports.iter().find(|(_, port)| {
self.port_state
.get(port)
.map(|port_state| {
dbg!(port_state).status == PortStatus::Disconnected
&& dbg!(now.saturating_sub(Duration::from_secs(port_state.last_change)))
>= PORT_OWNERSHIP_TIMEOUT
})
.unwrap_or(true)
});
dbg!(&removable_entry);
if let Some((&old_number, &port)) = removable_entry {
self.register_update();
println!("reused port {port} which used to be allocated to {old_number} which wasn't connected in a long time");
assert!(self.allocated_ports.remove(&old_number).is_some());
return Some(port);
}
None // TODO
}
@ -289,92 +354,241 @@ impl PortHandler {
self.allocated_ports.remove(&number);
self.free_ports.remove(&port);
}
}
fn open_port(&mut self, port: Port) -> Option<TcpListener> {
todo!()
}
fn close_port_for(&mut self, number: Number, listener: TcpListener) -> anyhow::Result<()> {
todo!()
}
#[derive(Debug, Default)]
struct HandlerMetadata {
number: Option<Number>,
port: Option<Port>,
}
async fn connection_handler(
port_handler: Arc<Mutex<PortHandler>>,
handler_metadata: &mut HandlerMetadata,
port_handler: &Mutex<PortHandler>,
stream: &mut TcpStream,
) -> anyhow::Result<()> {
let (mut reader, mut writer) = stream.split();
let mut packet = Packet::recv(&mut reader).await?;
let mut packet = Packet::default();
select! {
res = packet.recv_into_cancelation_safe(&mut reader) => res?,
_ = sleep(AUTH_TIMEOUT) => {
writer.write_all(REJECT_TIMEOUT).await?;
return Ok(());
}
}
let RemConnect { number, pin } = packet.as_rem_connect()?;
handler_metadata.number = Some(number);
let mut authenticated = false;
let (port, listener) = loop {
let port = port_handler
.lock()
.unwrap()
.allocate_port_for_number(number);
let mut updated_server = false;
let port = port_handler.lock().await.allocate_port_for_number(number);
println!("allocated port: {:?}", port);
let Some(port) = port else {
writer.write_all(&reject_static(b"oop")).await?;
writer.write_all(REJECT_OOP).await?;
return Ok(());
};
let ip = dyn_ip_update(number, pin, port).await?;
if !authenticated {
let _ip = dyn_ip_update(number, pin, port).await?;
authenticated = true;
updated_server = true;
}
let listener = TcpListener::bind((BIND_IP, port)).await;
let mut port_handler = port_handler.lock().await;
let listener = match listener {
Ok(listener) => break (port, listener),
Err(err) => {
port_handler.lock().unwrap().mark_port_error(number, port);
// tokio::time::sleep(Duration::from_millis(300)).await;
let listener = if let Some((listener, _package)) = port_handler.stop_rejector(port).await {
Ok(listener)
} else {
TcpListener::bind((BIND_IP, port)).await
};
match listener {
Ok(listener) => {
if !updated_server {
let _ip = dyn_ip_update(number, pin, port).await?;
}
port_handler
.port_state
.entry(port)
.or_default()
.new_state(PortStatus::Idle);
handler_metadata.port = Some(port);
break (port, listener);
}
Err(_err) => {
port_handler.mark_port_error(number, port);
continue;
}
};
};
#[derive(Debug)]
enum Foo {
Caller { stream: TcpStream, addr: SocketAddr },
Packet { packet: Packet },
enum Result {
Caller {
packet: Packet,
stream: TcpStream,
addr: SocketAddr,
},
Packet {
packet: Packet,
},
}
let result = select! {
kind = Packet::peek_packet_kind(&mut reader) => {
packet.recv_into(&mut reader).await?;
Foo::Packet { packet }
},
caller = listener.accept() => {
let (stream, addr) = caller?;
Foo::Caller { stream, addr }
},
let mut last_ping_sent_at = Instant::now();
let mut last_ping_received_at = Instant::now();
let result = loop {
let now = Instant::now();
select! {
caller = listener.accept() => {
let (stream, addr) = caller?;
break Result::Caller { packet, stream, addr }
},
_ = Packet::peek_packet_kind(&mut reader) => {
packet.recv_into(&mut reader).await?;
if packet.kind() == PacketKind::Ping {
last_ping_received_at = now;
} else {
break Result::Packet { packet }
}
},
_ = sleep(now.saturating_duration_since(last_ping_sent_at).saturating_sub(SEND_PING_INTERVAL)) => {
writer.write_all(bytemuck::bytes_of(& Header { kind: PacketKind::Ping.raw(), length: 0 })).await?;
last_ping_sent_at = now;
}
_ = sleep(now.saturating_duration_since(last_ping_received_at).saturating_sub(PING_TIMEOUT)) => {
writer.write_all(REJECT_TIMEOUT).await?;
return Ok(());
}
}
};
dbg!(&result);
let (mut client, mut packet) = match result {
Result::Packet { packet } => {
if matches!(
packet.kind(),
packets::PacketKind::End | packets::PacketKind::Reject
) {
println!("got disconnect packet: {packet:?}");
match result {
Foo::Caller { stream, addr } => todo!(),
Foo::Packet { mut packet } => {
match packet.kind() {
packets::PacketKind::End => {
packet.header = Header { kind: 3, length: 0 };
packet.data.clear();
}
packets::PacketKind::Reject => {}
kind => bail!("unexpected packet: {kind:?}"),
port_handler
.lock()
.await
.start_rejector(port, listener, packet)?;
return Ok(());
} else {
bail!("unexpected packet: {:?}", packet.kind())
}
port_handler
.lock()
.unwrap()
.start_rejector(port, listener, packet);
}
Result::Caller {
mut packet,
stream,
addr,
} => {
println!("got caller from: {addr}");
packet.data.clear();
match addr.ip() {
IpAddr::V4(addr) => packet.data.extend_from_slice(&addr.octets()),
IpAddr::V6(addr) => packet.data.extend_from_slice(&addr.octets()),
}
packet.header = Header {
kind: PacketKind::RemCall.raw(),
length: packet.data.len() as u8,
};
packet.send(&mut writer).await?;
(stream, packet)
}
};
select! {
res = packet.recv_into_cancelation_safe(&mut reader) => res?,
_ = sleep(CALL_ACK_TIMEOUT) => {
writer.write_all(REJECT_TIMEOUT).await?;
return Ok(());
}
}
Ok(())
match packet.kind() {
PacketKind::End | PacketKind::Reject => {
port_handler
.lock()
.await
.start_rejector(port, listener, packet)?;
return Ok(());
}
PacketKind::RemAck => {
packet.header = Header {
kind: PacketKind::Reject.raw(),
length: 4,
};
packet.data.clear();
packet.data.extend_from_slice(b"occ");
packet.data.push(0);
{
let mut port_handler = port_handler.lock().await;
port_handler.register_update();
port_handler
.port_state
.entry(port)
.or_default()
.new_state(PortStatus::InCall);
port_handler.start_rejector(port, listener, packet)?;
}
select! {
_ = tokio::io::copy_bidirectional(stream, &mut client) => {}
_ = sleep(CALL_TIMEOUT) => {}
}
{
let mut port_handler = port_handler.lock().await;
port_handler.register_update();
port_handler
.port_state
.entry(port)
.or_default()
.new_state(PortStatus::Disconnected);
port_handler
.change_rejector(port, |packet| {
packet.data.clear();
packet.data.extend_from_slice(b"nc");
packet.data.push(0);
packet.header = Header {
kind: PacketKind::Reject.raw(),
length: packet.data.len() as u8,
};
})
.await?;
}
return Ok(());
}
kind => bail!("unexpected packet: {:?}", kind),
}
}
#[tokio::main]
@ -385,9 +599,9 @@ async fn main() -> anyhow::Result<()> {
panic!("no allowed ports");
}
let db_path = PathBuf::from("db.json");
let cache_path = PathBuf::from("cache.json");
let mut port_handler = PortHandler::load_or_default(&db_path)?;
let mut port_handler = PortHandler::load_or_default(&cache_path);
port_handler.update_allowed_ports(&config.allowed_ports);
let port_handler = Arc::new(Mutex::new(port_handler));
@ -397,9 +611,9 @@ async fn main() -> anyhow::Result<()> {
tokio::spawn(async move {
let mut last_store = None;
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
sleep(Duration::from_secs(1)).await;
let port_handler = port_handler.lock().unwrap();
let port_handler = port_handler.lock().await;
if let Some(last_update) = port_handler.last_update {
let should_store = last_store
@ -408,22 +622,26 @@ async fn main() -> anyhow::Result<()> {
if should_store {
last_store = Some(last_update);
port_handler.store(&db_path).unwrap();
port_handler.store(&cache_path).unwrap();
}
}
}
});
}
let listener = TcpListener::bind(("127.0.0.1", 11812)).await?;
let listener = TcpListener::bind(("0.0.0.0", 11820)).await?;
while let Ok((mut stream, addr)) = listener.accept().await {
println!("connection from {addr}");
let port_handler = port_handler.clone();
let mut handler_metadata = HandlerMetadata::default();
tokio::spawn(async move {
if let Err(err) = connection_handler(port_handler, &mut stream).await {
let res = connection_handler(&mut handler_metadata, &port_handler, &mut stream).await;
if let Err(err) = res {
println!("client at {addr} had an error: {err}");
let mut packet = Packet::default();
@ -432,13 +650,26 @@ async fn main() -> anyhow::Result<()> {
packet.data.truncate(0xfe);
packet.data.push(0);
packet.header = Header {
kind: 0xff,
kind: PacketKind::Error.raw(),
length: packet.data.len() as u8,
};
let (_, mut writer) = stream.split();
let _ = packet.send(&mut writer).await;
}
// if let Some(number) = handler_metadata.number {
//
// }
if let Some(port) = handler_metadata.port {
if let Some(port_state) = port_handler.lock().await.port_state.get_mut(&port) {
port_state.new_state(PortStatus::Disconnected);
}
}
sleep(Duration::from_secs(3)).await;
let _ = stream.shutdown().await;
});
}

View File

@ -1,5 +1,3 @@
use std::{ffi::CString, mem::discriminant};
use anyhow::bail;
use bytemuck::{Pod, Zeroable};
use tokio::{
@ -7,25 +5,13 @@ use tokio::{
net::tcp::{ReadHalf, WriteHalf},
};
pub const fn reject_static<const N: usize>(message: &[u8; N]) -> [u8; N + 2] {
let mut pkg = [0u8; N + 2];
pkg[0] = 4;
pkg[1] = message.len() as u8;
let mut i = 0;
while i < message.len() {
pkg[i + 2] = message[i];
i += 1;
}
pkg
}
pub const REJECT_OCC: &[u8; 6] = b"\x04\x04occ\x00";
pub const REJECT_NC: &[u8; 5] = b"\x04\x03nc\x00";
pub const REJECT_OOP: &[u8; 6] = b"\x04\x04oop\x00";
pub const REJECT_TIMEOUT: &[u8; 10] = b"\x04\x08timeout\x00";
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketKind {
Unknown(u8),
Ping = 0x00,
DynIpUpdate = 0x01,
DynIpUpdateResponse = 0x02,
End = 0x03,
@ -34,6 +20,7 @@ pub enum PacketKind {
RemConfirm = 0x82,
RemCall = 0x83,
RemAck = 0x84,
Unknown(u8),
Error = 0xff,
}
@ -42,6 +29,7 @@ impl PacketKind {
use PacketKind::*;
match raw {
0x00 => Ping,
0x01 => DynIpUpdate,
0x02 => DynIpUpdateResponse,
0x03 => End,
@ -55,11 +43,11 @@ impl PacketKind {
}
}
fn kind(&self) -> u8 {
pub fn raw(&self) -> u8 {
use PacketKind::*;
match self {
Unknown(value) => *value,
Ping => 0,
DynIpUpdate => 0x01,
DynIpUpdateResponse => 0x02,
End => 0x03,
@ -69,6 +57,8 @@ impl PacketKind {
RemCall => 0x83,
RemAck => 0x84,
Error => 0xff,
Unknown(value) => *value,
}
}
}
@ -111,10 +101,18 @@ impl Packet {
}
}
pub async fn recv(stream: &mut ReadHalf<'_>) -> std::io::Result<Packet> {
let mut packet = Packet::default();
packet.recv_into(stream).await?;
Ok(packet)
pub async fn recv_into_cancelation_safe(
&mut self,
stream: &mut ReadHalf<'_>,
) -> std::io::Result<()> {
// Makes sure all data is available before reading
let header_bytes = bytemuck::bytes_of_mut(&mut self.header);
stream.peek(header_bytes).await?;
self.data.resize(self.header.length as usize + 2, 0);
stream.peek(&mut self.data).await?;
// All data is available. Read the data
self.recv_into(stream).await
}
pub async fn recv_into(&mut self, stream: &mut ReadHalf<'_>) -> std::io::Result<()> {
@ -161,7 +159,7 @@ impl Packet {
pub async fn dyn_ip_update(number: u32, pin: u16, port: u16) -> anyhow::Result<std::net::Ipv4Addr> {
let mut packet = Packet::default();
packet.header = Header {
kind: PacketKind::DynIpUpdate.kind(),
kind: PacketKind::DynIpUpdate.raw(),
length: 8,
};
@ -171,7 +169,8 @@ pub async fn dyn_ip_update(number: u32, pin: u16, port: u16) -> anyhow::Result<s
packet.data.extend_from_slice(&pin.to_le_bytes());
packet.data.extend_from_slice(&port.to_le_bytes());
let mut socket = tokio::net::TcpStream::connect(("127.0.0.1", 11811)).await?;
let mut socket = tokio::net::TcpStream::connect(("tlnserv.teleprinter.net", 11811)).await?;
// 127.0.0.1
let (mut reader, mut writer) = socket.split();