name tasks

This commit is contained in:
2023-03-18 16:27:42 +01:00
parent 931b976a34
commit 9a3bac503c
4 changed files with 694 additions and 30 deletions

View File

@@ -10,6 +10,7 @@ use std::{
use anyhow::{bail, Context};
use debug_server::debug_server;
use futures::Future;
use packets::{Header, Packet, RemConnect};
use serde::{Deserialize, Deserializer};
use tokio::{
@@ -31,6 +32,8 @@ const PORT_OWNERSHIP_TIMEOUT: Duration = Duration::from_secs(1 * 60 * 60);
const PING_TIMEOUT: Duration = Duration::from_secs(30);
const SEND_PING_INTERVAL: Duration = Duration::from_secs(20);
const CACHE_STORE_INTERVAL: Duration = Duration::from_secs(5);
#[cfg(feature = "debug_server")]
mod debug_server;
mod packets;
@@ -89,8 +92,31 @@ impl Config {
}
}
#[cfg(not(feature = "tokio_console"))]
fn spawn<T: Send + 'static>(
_name: &str,
future: impl Future<Output = T> + Send + 'static,
) -> tokio::task::JoinHandle<T> {
tokio::spawn(future)
}
#[cfg(feature = "tokio_console")]
fn spawn<T: Send + 'static>(
name: &str,
future: impl Future<Output = T> + Send + 'static,
) -> tokio::task::JoinHandle<T> {
tokio::task::Builder::new()
.name(name)
.spawn(future)
.unwrap_or_else(|err| panic!("failed to spawn {name:?}: {err:?}"))
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
#[cfg(feature = "tokio_console")]
console_subscriber::init();
let config = Arc::new(Config::load("config.json")?);
if config.allowed_ports.is_empty() {
@@ -99,31 +125,40 @@ async fn main() -> anyhow::Result<()> {
let cache_path = PathBuf::from("cache.json");
let mut port_handler = PortHandler::load_or_default(&cache_path);
let (change_sender, mut change_receiver) = tokio::sync::watch::channel(Instant::now());
let mut port_handler = PortHandler::load_or_default(&cache_path, change_sender);
port_handler.update_allowed_ports(&config.allowed_ports);
let port_handler = Arc::new(Mutex::new(port_handler));
{
let port_handler = port_handler.clone();
tokio::spawn(async move {
let mut last_store = None;
spawn("cache daemon", async move {
let mut last_store = Instant::now() - 2 * CACHE_STORE_INTERVAL;
let mut change_timeout = None;
loop {
sleep(Duration::from_secs(1)).await;
if let Some(change_timeout) = change_timeout.take() {
tokio::time::timeout(change_timeout, change_receiver.changed())
.await
.ok()
.unwrap_or(Ok(()))
} else {
change_receiver.changed().await
}
.expect("failed to wait for cache changes");
let port_handler = port_handler.lock().await;
let time_since_last_store = last_store.elapsed();
if let Some(last_update) = port_handler.last_update {
let should_store = last_store
.map(|last_store| last_update > last_store)
.unwrap_or(true);
if time_since_last_store > CACHE_STORE_INTERVAL {
let port_handler = port_handler.lock().await;
if should_store {
last_store = Some(last_update);
if let Err(err) = port_handler.store(&cache_path) {
println!("failed to store cache: {err:?}");
}
last_store = Instant::now();
if let Err(err) = port_handler.store(&cache_path) {
println!("failed to store cache: {err:?}");
}
} else {
change_timeout = Some(CACHE_STORE_INTERVAL - time_since_last_store);
}
}
});
@@ -132,7 +167,10 @@ async fn main() -> anyhow::Result<()> {
#[cfg(feature = "debug_server")]
if let Some(debug_server_addr) = config.debug_server_addr {
println!("starting debug server on {debug_server_addr:?}");
tokio::spawn(debug_server(debug_server_addr, port_handler.clone()));
spawn(
"debug server",
debug_server(debug_server_addr, port_handler.clone()),
);
}
let listener = TcpListener::bind(config.listen_addr).await?;
@@ -146,7 +184,7 @@ async fn main() -> anyhow::Result<()> {
let mut handler_metadata = HandlerMetadata::default();
tokio::spawn(async move {
spawn(&format!("connection to {addr}"), async move {
use futures::future::FutureExt;
let res = std::panic::AssertUnwindSafe(connection_handler(

View File

@@ -15,7 +15,8 @@ use serde::{Deserialize, Serialize};
use tokio::{net::TcpListener, sync::Mutex, task::JoinHandle, time::Instant};
use crate::{
packets::Packet, Config, Number, Port, UnixTimestamp, PORT_OWNERSHIP_TIMEOUT, PORT_RETRY_TIME,
packets::Packet, spawn, Config, Number, Port, UnixTimestamp, PORT_OWNERSHIP_TIMEOUT,
PORT_RETRY_TIME,
};
#[derive(Default, Serialize, Deserialize)]
@@ -23,6 +24,9 @@ pub struct PortHandler {
#[serde(skip)]
pub last_update: Option<Instant>,
#[serde(skip)]
pub change_sender: Option<tokio::sync::watch::Sender<Instant>>,
#[serde(skip)]
port_guards: HashMap<Port, Rejector>,
@@ -194,7 +198,13 @@ impl PortHandler {
}
pub fn register_update(&mut self) {
self.last_update = Some(Instant::now());
let now = Instant::now();
self.last_update = Some(now);
self.change_sender
.as_ref()
.expect("PortHandler is missing it's change_sender")
.send(now)
.expect("failed to notify cache writer");
}
pub fn store(&self, cache: &Path) -> anyhow::Result<()> {
@@ -207,13 +217,21 @@ impl PortHandler {
Ok(())
}
pub fn load(cache: &Path) -> std::io::Result<Self> {
pub fn load(
cache: &Path,
change_sender: tokio::sync::watch::Sender<Instant>,
) -> std::io::Result<Self> {
println!("loading cache");
Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?)
let mut cache: Self = serde_json::from_reader(BufReader::new(File::open(cache)?))?;
cache.change_sender = Some(change_sender);
Ok(cache)
}
pub fn load_or_default(cache: &Path) -> Self {
Self::load(cache).unwrap_or_else(|err| {
pub fn load_or_default(
cache: &Path,
change_sender: tokio::sync::watch::Sender<Instant>,
) -> Self {
Self::load(cache, change_sender).unwrap_or_else(|err| {
println!("failed to parse cache file at {cache:?} using empty cache. error: {err}");
Self::default()
})
@@ -311,12 +329,13 @@ impl Debug for Rejector {
impl Rejector {
fn start(listener: TcpListener, packet: Packet) -> Self {
let port = listener.local_addr().map(|addr| addr.port()).unwrap_or(0);
let state = Arc::new((Mutex::new(listener), packet));
let handle = {
let state = state.clone();
tokio::spawn(async move {
spawn(&format!("rejector for port {port}",), async move {
let (listener, packet) = state.as_ref();
let listener = listener.lock().await;