diff --git a/src/allocator.rs b/src/allocator.rs index c5c07ff..a73807c 100644 --- a/src/allocator.rs +++ b/src/allocator.rs @@ -6,14 +6,14 @@ use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; use crate::{Db, FilePointer, FileRange, PagePointer, RawFilePointer, PAGE_SIZE, U16, U32, U64}; #[derive(Clone, Copy, PartialEq, Eq, Debug)] -enum SlabKind { +pub enum SlabKind { SingleBytes, RelativeFreeList, AbsoluteFreeList, } impl SlabKind { - fn for_size(size: u32) -> Self { + pub fn for_size(size: u32) -> Self { if size == 1 { Self::SingleBytes } else if size < size_of::() as u32 { @@ -313,7 +313,7 @@ impl GeneralPurposeAllocator { } } -fn div_round_up(a: u64, b: u64) -> u64 { +pub(crate) fn div_round_up(a: u64, b: u64) -> u64 { (a + b - 1) / b } @@ -353,7 +353,10 @@ impl<'db, R> Iterator for SlabListIterator<'db, R> { impl SlabListHeader { pub fn capacity(&self) -> u32 { - (self.size.get() - size_of::() as u32) / size_of::() as u32 + (self.size() - size_of::() as u32) / size_of::() as u32 + } + pub fn size(&self) -> u32 { + self.size.get() } } @@ -364,7 +367,7 @@ impl SlabListPointer { (!ptr.0.is_null()).then_some(ptr) } - fn read_header(self, db: &Db) -> SlabListHeader { + pub fn read_header(self, db: &Db) -> SlabListHeader { unsafe { db.read(self.0) } } @@ -458,9 +461,9 @@ pub struct Slab { #[derive(Clone, Copy, FromBytes, FromZeroes, AsBytes, Unaligned)] #[repr(C)] -struct RelativeFreeListHeader { - next_page: PagePointer, - first: U16, +pub struct RelativeFreeListHeader { + pub next_page: PagePointer, + pub first: U16, } impl RelativeFreeListHeader { @@ -580,10 +583,14 @@ impl SlabPointer { } } - pub fn set_head(&self, db: &mut Db, next: RawFilePointer) { + fn set_head(&self, db: &mut Db, next: RawFilePointer) { self.modify(db).head = next; } + pub fn head(&self, db: &Db) -> RawFilePointer { + self.read(db).head + } + pub fn allocate_page(&self, db: &mut Db) -> RawFilePointer { let Slab { head, size } = self.read(db); diff --git a/src/lib.rs b/src/lib.rs index e51cb8b..491ed38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -196,6 +196,9 @@ impl PagePointer { fn null() -> Self { Self::nth(0) } + fn is_null(self) -> bool { + self == Self::null() + } } #[derive(Clone, Copy, FromBytes, FromZeroes, AsBytes, Unaligned, PartialEq, Eq)] @@ -286,6 +289,12 @@ pub struct Reader { state: Arc>>, } +impl Reader { + fn get(&self) -> Arc> { + self.state.get() + } +} + pub struct Db { file: File, map: MmapMut, @@ -616,9 +625,12 @@ impl Db { #[cfg(test)] mod tests { + use crate::allocator::{div_round_up, RelativeFreeListHeader, SlabKind}; + use super::*; use mapped::ReaderTrait; use std::io::Write; + use std::ops::Shl; use std::process::Stdio; #[derive(Debug, Clone, Copy)] @@ -829,10 +841,13 @@ mod tests { let mut snapshots = VecDeque::new(); for i in 0..20 { + dbg!(i); db.transaction(|transaction| { let root = transaction.root(); - let root = if root.is_null() { + let root = if !root.is_null() { + root + } else { let (root, data) = transaction.allocate::(); *data = DataHeader { @@ -840,8 +855,6 @@ mod tests { list: FilePointer::null(), }; - root - } else { root }; @@ -879,6 +892,23 @@ mod tests { root }); + validate_db(&db, |snaphot, coverage| { + coverage.set_allocated(snaphot.root.range()); + let data = snaphot.read(snaphot.root); + + let mut next = data.list; + while !next.is_null() { + coverage.set_allocated(next.range()); + next = snaphot.read(next).next; + } + + for SnapshotAndFreeList { to_free, .. } in &db.snapshots { + for &range in to_free { + coverage.set_range(range, CoverageKind::Free); + } + } + }); + snapshots.push_back(db.create_reader().state.get()); if snapshots.len() > 10 { drop(snapshots.pop_front()); @@ -886,6 +916,8 @@ mod tests { } } + // TODO: allocate some variably sized strings + for (i, snapshot) in snapshots.iter().enumerate() { let root = snapshot.read(snapshot.root); @@ -919,6 +951,283 @@ mod tests { // hexdump(db.map.as_bytes()); } + #[repr(u8)] + #[derive(Clone, Copy)] + enum CoverageKind { + Unaccounted = 0b00, + Allocated = 0b01, + Free = 0b10, + Metadata = 0b11, + } + + impl CoverageKind { + fn color(self) -> &'static str { + match self { + CoverageKind::Unaccounted => "31", + CoverageKind::Allocated => "32", + CoverageKind::Free => "34", + CoverageKind::Metadata => "35", + } + } + } + + impl CoverageKind { + fn from_bits(a: bool, b: bool) -> Self { + let res = match (a, b) { + (false, false) => Self::Unaccounted, + (false, true) => Self::Allocated, + (true, false) => Self::Free, + (true, true) => Self::Metadata, + }; + assert_eq!(res as u8, ((a as u8) << 1) + b as u8); + res + } + + fn to_bits(self) -> (bool, bool) { + (self as u8 & 0b10 != 0, self as u8 & 0b01 != 0) + } + } + + struct CoverageMap { + data_0: Vec, + data_1: Vec, + empty_bits: u8, + } + + impl CoverageMap { + fn new(len: usize) -> Self { + let bits = div_round_up(len as u64, 8) as usize; + Self { + data_0: vec![0; bits], + data_1: vec![0; bits], + empty_bits: (8 - len % 8) as u8, + } + } + + #[must_use] + fn set(&mut self, i: usize, kind: CoverageKind) -> bool { + let i_byte = i / 8; + let i_bit = i % 8; + let mask = 1 << i_bit; + + let (set_0, set_1) = kind.to_bits(); + + if i_byte >= self.data_0.len() { + return false; + } + + let is_set = self.data_0[i_byte] & mask != 0 || self.data_1[i_byte] & mask != 0; + + if is_set { + return false; + } + + self.data_0[i_byte] |= mask * set_0 as u8; + self.data_1[i_byte] |= mask * set_1 as u8; + + true + } + + #[must_use] + fn try_set_range(&mut self, range: FileRange, kind: CoverageKind) -> bool { + range.as_range().all(|i| self.set(i, kind)) + } + + fn set_allocated(&mut self, range: FileRange) { + self.set_range(range, CoverageKind::Allocated); + } + + fn set_range(&mut self, range: FileRange, kind: CoverageKind) { + assert!( + self.try_set_range(range, kind), + "possible allocator corruption" + ) + } + + fn all_covered(&self) -> bool { + let len = self.data_0.len(); + for (i, (&byte_0, &byte_1)) in self.data_0.iter().zip(self.data_1.iter()).enumerate() { + let byte = byte_0 | byte_1; + if i == len - 1 { + if byte != u8::MAX.overflowing_shl(self.empty_bits as u32).0 { + return false; + } + } else if byte != u8::MAX { + return false; + } + } + true + } + + fn set_color(res: &mut String, color: &str) { + res.push_str("\x1b["); + res.push_str(color); + res.push('m'); + } + + fn print(&self) -> String { + let mut res = String::new(); + + let mut prev = ""; + for (i, (&byte_0, &byte_1)) in self.data_0.iter().zip(self.data_1.iter()).enumerate() { + let byte = byte_0 | byte_1; + + fn all_equal(bits: u8) -> bool { + bits == 0 || bits == u8::MAX + } + + let kind = if all_equal(byte_0) && all_equal(byte_1) { + Some(CoverageKind::from_bits(byte_0 & 1 == 1, byte_1 & 1 == 1)) + } else { + None + }; + + if i != 0 { + if i as u64 % (PAGE_SIZE / 8 / 8) == 0 { + res.push('\n'); + } + if i as u64 % (PAGE_SIZE / 8) == 0 { + Self::set_color(&mut res, ""); + prev = ""; + res.push_str(&"-".repeat((PAGE_SIZE / 8 / 8) as usize)); + res.push('\n'); + } + } + + let color = kind.map(CoverageKind::color).unwrap_or("33"); + + if color != prev { + Self::set_color(&mut res, color); + } + prev = color; + + res.push(char::from_u32(0x2800 + byte as u32).unwrap()); + } + + Self::set_color(&mut res, ""); + + res.push('\n'); + + res + } + + fn assert_covered(&self) { + if !self.all_covered() { + panic!("Space in the file was lost\n{}", self.print()); + } + } + } + + #[test] + fn coverage_map_works() { + let mut coverage = CoverageMap::new(40); + assert!(!coverage.all_covered()); + assert!(coverage.try_set_range(RawFilePointer::null().range(20), CoverageKind::Metadata)); + assert!(!coverage.all_covered()); + assert!(coverage.try_set_range((RawFilePointer::null() + 20).range(20), CoverageKind::Free)); + assert!(coverage.all_covered()); + assert!(!coverage.try_set_range( + (RawFilePointer::null() + 40).range(8), + CoverageKind::Allocated + )); + assert!(!coverage.try_set_range( + (RawFilePointer::null() + 50).range(10), + CoverageKind::Allocated + )); + } + + fn validate_db(db: &Db, f: impl FnOnce(&Snapshot, &mut CoverageMap)) { + let mut coverage = CoverageMap::new(db.map.len()); + + let snapshot = &*db.state.get(); + + coverage.set_range(Db::::header_ptr().range(), CoverageKind::Metadata); + + // general purpose + { + let head = Db::::header_ptr().allocator_state_ptr().general_ptr(); + let mut next = *snapshot.read(head); + while !next.is_null() { + let size = GeneralPurposeAllocator::size(db, next); + coverage.set_range(next.into_raw().range(size), CoverageKind::Free); + next = *snapshot.read(next.next_ptr()); + } + } + + // slabs + { + let slabs = *snapshot.read(Db::::header_ptr().allocator_state_ptr().slabs_ptr()); + + let mut next = Some(slabs); + while let Some(slabs) = next { + coverage.set_range( + slabs + .0 + .into_raw() + .range(slabs.read_header(db).size() as u64), + CoverageKind::Metadata, + ); + next = slabs.next(db); + + for slab in slabs.iter(db) { + let size = slab.size(db); + let head = slab.head(db); + + match SlabKind::for_size(size) { + SlabKind::SingleBytes => todo!(), + SlabKind::RelativeFreeList => { + let (mut page, offset) = head.page_offset(); + + while !page.is_null() { + let header = + FilePointer::::new(page.start()); + + coverage.set_range(header.range(), CoverageKind::Metadata); + + let header = snapshot.read(header); + + page = header.next_page; + + let mut next = header.first.get(); + while next != 0 { + let next_ptr = FilePointer::::new( + RawFilePointer::from_page_and_offset(page, next), + ); + coverage.set_range( + next_ptr.into_raw().range(size as u64), + CoverageKind::Free, + ); + next = snapshot.read(next_ptr).get(); + } + } + + todo!(); + } + SlabKind::AbsoluteFreeList => { + let mut next = head; + while !next.is_null() { + let next_ptr = FilePointer::::new(next); + coverage.set_range( + next_ptr.into_raw().range(size as u64), + CoverageKind::Free, + ); + next = *snapshot.read(next_ptr); + } + } + } + } + } + } + + f(snapshot, &mut coverage); + + print!("{}", coverage.print()); + + if !coverage.all_covered() { + panic!("space in the file was lost..."); + } + } + fn hexdump(bytes: &[u8]) { let mut child = std::process::Command::new("hexdump") .arg("-C") diff --git a/src/transaction.rs b/src/transaction.rs index 7a64e05..16f586f 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -48,6 +48,10 @@ impl<'t, R> TransactionHandle<'t, R> { fn read_ptr_raw(&self, range: FileRange) -> FileRange { if let Some(&replaced) = self.replaced.get(&range.start) { assert_eq!(replaced.from, range); + + // TODO: replacing this with unwrap_or(range) + // will access the original region, which can't have actually been + // allocated, but was logically freed. replaced.to.expect("use after free") } else if let Some(&new) = self.new.get(&range.start) { assert_eq!(new, range);