diff --git a/src/transaction.rs b/src/transaction.rs new file mode 100644 index 0000000..7a93b00 --- /dev/null +++ b/src/transaction.rs @@ -0,0 +1,170 @@ +use std::{cell::RefCell, collections::HashMap, mem::size_of}; + +use zerocopy::{AsBytes, FromBytes}; + +use crate::{mapped::ReaderTrait, Db, FilePointer, FileRange}; + +#[derive(Clone, Copy)] +struct Replaced { + from: FileRange, + to: Option, +} + +pub struct TransactionHandle<'t> { + db: &'t mut Db, + replaced: HashMap, + new: HashMap, +} + +impl<'t> ReaderTrait for TransactionHandle<'t> { + fn read_raw(&self, ptr: FileRange) -> &[u8] { + self.reference_raw(ptr) + } +} +impl<'t> TransactionHandle<'t> { + pub fn new(db: &'t mut Db) -> Self { + Self { + db, + replaced: HashMap::new(), + new: HashMap::new(), + } + } + + pub fn to_free(&self) -> Vec { + self.replaced + .values() + .map(|replaced| replaced.from) + .collect() + } + + pub fn read_ptr(&self, range: FileRange) -> FileRange { + if let Some(&replaced) = self.replaced.get(&range.start) { + assert_eq!(replaced.from, range); + replaced.to.expect("use after free") + } else if let Some(&new) = self.new.get(&range.start) { + assert_eq!(new, range); + new + } else { + range + } + } + + pub unsafe fn write_ptr(&mut self, range: FileRange) -> FileRange { + let new_range = if let Some(&replaced) = self.replaced.get(&range.start) { + assert_eq!(replaced.from, range); + replaced.to + } else if let Some(&new) = self.new.get(&range.start) { + assert_eq!(new, range); + Some(new) + } else { + None + }; + + if let Some(range) = new_range { + range + } else { + let (new, _) = self.allocate_raw(range.len()); + + self.db.copy(range, new); + + let res = self.replaced.insert( + new.start, + Replaced { + from: range, + to: Some(new), + }, + ); + // TODO: removing this allows freeing a region and + // then implicitely reallocating it by writing to it. + // Should that be allowed? + assert!(res.is_none(), "use after free"); + + new + } + } + + pub fn reference_raw(&self, range: FileRange) -> &[u8] { + let range = self.read_ptr(range); + &self.db.map[range.as_range()] + } + + pub unsafe fn modify_raw(&mut self, range: FileRange) -> (FileRange, &mut [u8]) { + let range = self.write_ptr(range); + (range, &mut self.db.map[range.as_range()]) + } + + pub fn allocate_raw(&mut self, length: u64) -> (FileRange, &mut [u8]) { + unsafe { + let range = self.allocate_range(length); + + let res = self.new.insert(range.start, range); + debug_assert!(res.is_none()); + + (range, &mut self.db.map[range.as_range()]) + } + } + + pub fn modify_range( + &mut self, + range: FileRange, + ) -> (FileRange, &mut T) { + unsafe { + let (ptr, _) = self.modify_raw(range); + (ptr, self.db.modify_range(ptr)) + } + } + + pub fn modify(&mut self, at: FilePointer) -> (FileRange, &mut T) { + self.modify_range(at.range(size_of::() as u64)) + } + + pub fn allocate_size(&mut self, length: u64) -> (FileRange, &mut T) { + unsafe { + let (ptr, _) = self.allocate_raw(length); + (ptr, self.db.modify_range(ptr)) + } + } + + pub fn allocate(&mut self) -> (FileRange, &mut T) { + unsafe { + let (ptr, _) = self.allocate_raw(size_of::() as u64); + (ptr, self.db.modify_range(ptr)) + } + } + + pub fn free(&mut self, range: FileRange) { + let mut freed = false; + + if let Some(allocation) = self.new.remove(&range.start) { + assert_eq!(allocation, range); + self.db.free(range); + freed = true; + } + + for (_, replaced) in &mut self.replaced { + if replaced.from == range || replaced.to == Some(range) { + replaced.to = None; + freed = true; + } + } + + if !freed { + let res = self.replaced.insert( + range.start, + Replaced { + from: range, + to: None, + }, + ); + assert!(res.is_none()); + } + } + + fn allocate_range(&mut self, size: u64) -> FileRange { + self.db.allocate(size) + } + + pub fn root(&self) -> FilePointer { + self.db.header().root + } +}