name tasks
This commit is contained in:
70
src/main.rs
70
src/main.rs
@@ -10,6 +10,7 @@ use std::{
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use debug_server::debug_server;
|
||||
use futures::Future;
|
||||
use packets::{Header, Packet, RemConnect};
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use tokio::{
|
||||
@@ -31,6 +32,8 @@ const PORT_OWNERSHIP_TIMEOUT: Duration = Duration::from_secs(1 * 60 * 60);
|
||||
const PING_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const SEND_PING_INTERVAL: Duration = Duration::from_secs(20);
|
||||
|
||||
const CACHE_STORE_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
#[cfg(feature = "debug_server")]
|
||||
mod debug_server;
|
||||
mod packets;
|
||||
@@ -89,8 +92,31 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tokio_console"))]
|
||||
|
||||
fn spawn<T: Send + 'static>(
|
||||
_name: &str,
|
||||
future: impl Future<Output = T> + Send + 'static,
|
||||
) -> tokio::task::JoinHandle<T> {
|
||||
tokio::spawn(future)
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio_console")]
|
||||
fn spawn<T: Send + 'static>(
|
||||
name: &str,
|
||||
future: impl Future<Output = T> + Send + 'static,
|
||||
) -> tokio::task::JoinHandle<T> {
|
||||
tokio::task::Builder::new()
|
||||
.name(name)
|
||||
.spawn(future)
|
||||
.unwrap_or_else(|err| panic!("failed to spawn {name:?}: {err:?}"))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
#[cfg(feature = "tokio_console")]
|
||||
console_subscriber::init();
|
||||
|
||||
let config = Arc::new(Config::load("config.json")?);
|
||||
|
||||
if config.allowed_ports.is_empty() {
|
||||
@@ -99,31 +125,40 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let cache_path = PathBuf::from("cache.json");
|
||||
|
||||
let mut port_handler = PortHandler::load_or_default(&cache_path);
|
||||
let (change_sender, mut change_receiver) = tokio::sync::watch::channel(Instant::now());
|
||||
|
||||
let mut port_handler = PortHandler::load_or_default(&cache_path, change_sender);
|
||||
port_handler.update_allowed_ports(&config.allowed_ports);
|
||||
|
||||
let port_handler = Arc::new(Mutex::new(port_handler));
|
||||
|
||||
{
|
||||
let port_handler = port_handler.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut last_store = None;
|
||||
spawn("cache daemon", async move {
|
||||
let mut last_store = Instant::now() - 2 * CACHE_STORE_INTERVAL;
|
||||
let mut change_timeout = None;
|
||||
loop {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
if let Some(change_timeout) = change_timeout.take() {
|
||||
tokio::time::timeout(change_timeout, change_receiver.changed())
|
||||
.await
|
||||
.ok()
|
||||
.unwrap_or(Ok(()))
|
||||
} else {
|
||||
change_receiver.changed().await
|
||||
}
|
||||
.expect("failed to wait for cache changes");
|
||||
|
||||
let port_handler = port_handler.lock().await;
|
||||
let time_since_last_store = last_store.elapsed();
|
||||
|
||||
if let Some(last_update) = port_handler.last_update {
|
||||
let should_store = last_store
|
||||
.map(|last_store| last_update > last_store)
|
||||
.unwrap_or(true);
|
||||
if time_since_last_store > CACHE_STORE_INTERVAL {
|
||||
let port_handler = port_handler.lock().await;
|
||||
|
||||
if should_store {
|
||||
last_store = Some(last_update);
|
||||
if let Err(err) = port_handler.store(&cache_path) {
|
||||
println!("failed to store cache: {err:?}");
|
||||
}
|
||||
last_store = Instant::now();
|
||||
if let Err(err) = port_handler.store(&cache_path) {
|
||||
println!("failed to store cache: {err:?}");
|
||||
}
|
||||
} else {
|
||||
change_timeout = Some(CACHE_STORE_INTERVAL - time_since_last_store);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -132,7 +167,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
#[cfg(feature = "debug_server")]
|
||||
if let Some(debug_server_addr) = config.debug_server_addr {
|
||||
println!("starting debug server on {debug_server_addr:?}");
|
||||
tokio::spawn(debug_server(debug_server_addr, port_handler.clone()));
|
||||
spawn(
|
||||
"debug server",
|
||||
debug_server(debug_server_addr, port_handler.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind(config.listen_addr).await?;
|
||||
@@ -146,7 +184,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let mut handler_metadata = HandlerMetadata::default();
|
||||
|
||||
tokio::spawn(async move {
|
||||
spawn(&format!("connection to {addr}"), async move {
|
||||
use futures::future::FutureExt;
|
||||
|
||||
let res = std::panic::AssertUnwindSafe(connection_handler(
|
||||
|
||||
33
src/ports.rs
33
src/ports.rs
@@ -15,7 +15,8 @@ use serde::{Deserialize, Serialize};
|
||||
use tokio::{net::TcpListener, sync::Mutex, task::JoinHandle, time::Instant};
|
||||
|
||||
use crate::{
|
||||
packets::Packet, Config, Number, Port, UnixTimestamp, PORT_OWNERSHIP_TIMEOUT, PORT_RETRY_TIME,
|
||||
packets::Packet, spawn, Config, Number, Port, UnixTimestamp, PORT_OWNERSHIP_TIMEOUT,
|
||||
PORT_RETRY_TIME,
|
||||
};
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
@@ -23,6 +24,9 @@ pub struct PortHandler {
|
||||
#[serde(skip)]
|
||||
pub last_update: Option<Instant>,
|
||||
|
||||
#[serde(skip)]
|
||||
pub change_sender: Option<tokio::sync::watch::Sender<Instant>>,
|
||||
|
||||
#[serde(skip)]
|
||||
port_guards: HashMap<Port, Rejector>,
|
||||
|
||||
@@ -194,7 +198,13 @@ impl PortHandler {
|
||||
}
|
||||
|
||||
pub fn register_update(&mut self) {
|
||||
self.last_update = Some(Instant::now());
|
||||
let now = Instant::now();
|
||||
self.last_update = Some(now);
|
||||
self.change_sender
|
||||
.as_ref()
|
||||
.expect("PortHandler is missing it's change_sender")
|
||||
.send(now)
|
||||
.expect("failed to notify cache writer");
|
||||
}
|
||||
|
||||
pub fn store(&self, cache: &Path) -> anyhow::Result<()> {
|
||||
@@ -207,13 +217,21 @@ impl PortHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load(cache: &Path) -> std::io::Result<Self> {
|
||||
pub fn load(
|
||||
cache: &Path,
|
||||
change_sender: tokio::sync::watch::Sender<Instant>,
|
||||
) -> std::io::Result<Self> {
|
||||
println!("loading cache");
|
||||
Ok(serde_json::from_reader(BufReader::new(File::open(cache)?))?)
|
||||
let mut cache: Self = serde_json::from_reader(BufReader::new(File::open(cache)?))?;
|
||||
cache.change_sender = Some(change_sender);
|
||||
Ok(cache)
|
||||
}
|
||||
|
||||
pub fn load_or_default(cache: &Path) -> Self {
|
||||
Self::load(cache).unwrap_or_else(|err| {
|
||||
pub fn load_or_default(
|
||||
cache: &Path,
|
||||
change_sender: tokio::sync::watch::Sender<Instant>,
|
||||
) -> Self {
|
||||
Self::load(cache, change_sender).unwrap_or_else(|err| {
|
||||
println!("failed to parse cache file at {cache:?} using empty cache. error: {err}");
|
||||
Self::default()
|
||||
})
|
||||
@@ -311,12 +329,13 @@ impl Debug for Rejector {
|
||||
|
||||
impl Rejector {
|
||||
fn start(listener: TcpListener, packet: Packet) -> Self {
|
||||
let port = listener.local_addr().map(|addr| addr.port()).unwrap_or(0);
|
||||
let state = Arc::new((Mutex::new(listener), packet));
|
||||
|
||||
let handle = {
|
||||
let state = state.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
spawn(&format!("rejector for port {port}",), async move {
|
||||
let (listener, packet) = state.as_ref();
|
||||
|
||||
let listener = listener.lock().await;
|
||||
|
||||
Reference in New Issue
Block a user