use std::{cell::RefCell, collections::HashMap, fmt::Debug, mem::size_of}; use zerocopy::{AsBytes, FromBytes}; use crate::{mapped::ReaderTrait, Db, FilePointer, FileRange, RawFilePointer}; #[derive(Clone, Copy, Debug)] struct Replaced { from: FileRange, to: Option, } pub struct TransactionHandle<'t, R> { db: &'t mut Db, replaced: HashMap, new: HashMap, } impl<'t, R> Debug for TransactionHandle<'t, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TransactionHandle") .field("replaced", &self.replaced) .field("new", &self.new) .finish() } } impl<'t, R> ReaderTrait for TransactionHandle<'t, R> { fn read_raw(&self, ptr: FileRange) -> &[u8] { self.reference_raw(ptr) } } impl<'t, R> TransactionHandle<'t, R> { pub fn new(db: &'t mut Db) -> Self { Self { db, replaced: HashMap::new(), new: HashMap::new(), } } pub(crate) fn to_free(&self) -> Vec { self.replaced .values() .map(|replaced| replaced.from) .collect() } fn read_ptr(&self, src: FilePointer) -> FilePointer { FilePointer::from_range(self.read_ptr_raw(src.range())) } unsafe fn write_ptr(&mut self, src: FilePointer) -> FilePointer { FilePointer::from_range(self.write_ptr_raw(src.range())) } fn read_ptr_raw(&self, range: FileRange) -> FileRange { if let Some(&replaced) = self.replaced.get(&range.start) { assert_eq!(replaced.from, range); // TODO: replacing this with unwrap_or(range) // will access the original region, which can't have actually been // allocated, but was logically freed. replaced.to.unwrap_or_else(|| { dbg!(&self); panic!("use after free at: {range:?}") }) } else if let Some(&new) = self.new.get(&range.start) { debug_assert_eq!(new.start, range.start); assert!(new.end() >= range.end()); new.start.range(range.len()) } else { range } } unsafe fn write_ptr_raw(&mut self, range: FileRange) -> FileRange { let new_range = if let Some(&replaced) = self.replaced.get(&range.start) { assert_eq!(replaced.from, range); replaced.to } else if let Some(&new) = self.new.get(&range.start) { assert_eq!(new, range); Some(new) } else { None }; if let Some(range) = new_range { range } else { let (new, _) = self.allocate_range(range.len()); self.db.copy_nonoverlapping(range, new); let res = self.replaced.insert( range.start, Replaced { from: range, to: Some(new), }, ); // TODO: removing this allows freeing a region and // then implicitely reallocating it by writing to it. // Should that be allowed? assert!(res.is_none(), "use after free"); new } } pub fn reference_raw(&self, range: FileRange) -> &[u8] { let range = self.read_ptr_raw(range); &self.db.map[range.as_range()] } pub fn modify_raw(&mut self, range: FileRange) -> (FileRange, &mut [u8]) { let range = unsafe { self.write_ptr_raw(range) }; (range, &mut self.db.map[range.as_range()]) } pub fn modify_range( &mut self, range: FileRange, ) -> (FileRange, &mut T) { unsafe { let (ptr, _) = self.modify_raw(range); (ptr, self.db.modify_range(ptr)) } } pub fn modify( &mut self, at: FilePointer, ) -> (FilePointer, &mut T) { let (range, data) = self.modify_range(at.range()); (FilePointer::from_range(range), data) } pub fn set(&mut self, at: FilePointer, value: T) -> FilePointer { let (ptr, data) = self.modify(at); *data = value; ptr } pub fn allocate_range(&mut self, length: u64) -> (FileRange, &mut [u8]) { unsafe { let range = self.db.allocate(length); let res = self.new.insert(range.start, range); debug_assert!(res.is_none()); (range, &mut self.db.map[range.as_range()]) } } pub fn allocate(&mut self) -> (FilePointer, &mut T) { unsafe { let (range, _) = self.allocate_range(size_of::() as u64); (FilePointer::from_range(range), self.db.modify_range(range)) } } pub fn free(&mut self, at: FilePointer) { self.free_range(at.range()) } pub fn free_range(&mut self, range: FileRange) { let mut freed = false; if let Some(allocation) = self.new.remove(&range.start) { assert_eq!(allocation, range); self.db.free(range); freed = true; } for replaced in self.replaced.values_mut() { if replaced.from == range || replaced.to == Some(range) { replaced.to = None; freed = true; } } if !freed { let res = self.replaced.insert( range.start, Replaced { from: range, to: None, }, ); assert!(res.is_none()); } } pub fn root(&self) -> FilePointer { self.db.root() } }