mod blockindex; use crate::{ InvalidDeterminant, clock::{self, Counter, Timer}, }; use blockindex::BlockIndex; use std::io::Write; #[repr(u8)] #[non_exhaustive] #[derive(Debug)] pub enum SSToken { Start = 0, NewBlock = 1, NewSuperblock = 2, SuperblockSeq = 3, } impl TryFrom for SSToken { type Error = InvalidDeterminant; fn try_from(value: u8) -> std::result::Result { match value { 0 => Ok(SSToken::Start), 1 => Ok(SSToken::NewBlock), 2 => Ok(SSToken::NewSuperblock), 3 => Ok(SSToken::SuperblockSeq), _ => Err(InvalidDeterminant(value)), } } } impl From for u8 { fn from(value: SSToken) -> Self { match value { SSToken::Start => 0, SSToken::NewBlock => 1, SSToken::NewSuperblock => 2, SSToken::SuperblockSeq => 3, } } } pub(crate) struct Ctx { block_size: u32, superblock_size: u32, last_state: Vec, last_superseq: Vec, block_index: BlockIndex, superblock_index: BlockIndex, use_encode_state_comparisons: bool, } impl Ctx { pub fn new(block_size: u32, superblock_size: u32) -> Self { Self { block_size, superblock_size, last_state: vec![], last_superseq: vec![], block_index: BlockIndex::new(block_size as usize), superblock_index: BlockIndex::new(superblock_size as usize), use_encode_state_comparisons: true, } } } pub(crate) struct Decoder<'r, 'c, R: std::io::Read> { reader: &'r mut R, ctx: &'c mut Ctx, state_size: usize, finished: bool, readout_cursor: usize, } impl<'r, 'c, R: std::io::Read> Decoder<'r, 'c, R> { pub(crate) fn new(reader: &'r mut R, ctx: &'c mut Ctx, state_size: usize) -> Self { Self { reader, ctx, finished: false, readout_cursor: 0, state_size, } } fn readout(&mut self, mut buf: &mut [u8]) -> std::io::Result { match buf.write(&self.ctx.last_state[self.readout_cursor..]) { Err(e) => Err(e), Ok(sz) => { self.readout_cursor += sz; Ok(sz) } } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ParseState { WaitForStart, WaitForSuperblockSeq, Finished, } #[derive(thiserror::Error, Debug)] enum SSError { #[error("Invalid token {0}")] InvalidToken(#[from] InvalidDeterminant), #[error("Too many start tokens in stream")] TooManyStarts(), #[error("Unexpected {1:?} during {0:?}")] ParseError(ParseState, SSToken), #[error("Block {0} is the wrong size")] BlockWrongSize(u32), #[error("Superblock {0} is the wrong size")] SuperblockWrongSize(u32), #[error("Couldn't insert block at {1} on frame {0}")] BadBlockInsert(u64, u32), #[error("Couldn't insert superblock at {1} on frame {0}")] BadSuperblockInsert(u64, u32), } impl std::io::Read for Decoder<'_, '_, R> { /* a slightly degenerate read implementation in that it will keep * calling read on the inner reader until a complete checkpoint is * read, then return 0 for subsequent reads */ #[allow(clippy::too_many_lines)] fn read(&mut self, outbuf: &mut [u8]) -> std::io::Result { use ParseState as State; use rmp::decode as r; if self.finished { if self.readout_cursor == self.state_size { return Ok(0); } return self.readout(outbuf); } let stopwatch = clock::time(Timer::DecodeStatestream); let mut frame = 0; let mut state = State::WaitForStart; let mut buf = vec![0_u8; self.ctx.block_size as usize]; let mut superblock = vec![0_u32; self.ctx.superblock_size as usize]; loop { let tok: u8 = r::read_int(self.reader).map_err(std::io::Error::other)?; match ( state, SSToken::try_from(tok) .map_err(|e| std::io::Error::other(SSError::InvalidToken(e)))?, ) { (State::WaitForStart, SSToken::Start) => { frame = r::read_int(self.reader).map_err(std::io::Error::other)?; state = State::WaitForSuperblockSeq; } (_, SSToken::Start) => return Err(std::io::Error::other(SSError::TooManyStarts())), (State::WaitForSuperblockSeq, SSToken::NewBlock) => { let idx = r::read_int(self.reader).map_err(std::io::Error::other)?; let bin_len = r::read_bin_len(self.reader).map_err(std::io::Error::other)?; if bin_len != self.ctx.block_size { return Err(std::io::Error::other(SSError::BlockWrongSize(bin_len))); } self.reader.read_exact(&mut buf)?; // hashes += 1; if !self .ctx .block_index .insert_exact(idx, Box::from(buf.clone()), frame) { return Err(std::io::Error::other(SSError::BadBlockInsert(frame, idx))); } } (State::WaitForSuperblockSeq, SSToken::NewSuperblock) => { let idx = r::read_int(self.reader).map_err(std::io::Error::other)?; let arr_len = r::read_array_len(self.reader).map_err(std::io::Error::other)?; if arr_len != self.ctx.superblock_size { return Err(std::io::Error::other(SSError::SuperblockWrongSize(arr_len))); } for superblock_elt in &mut superblock { *superblock_elt = r::read_int(self.reader).map_err(std::io::Error::other)?; } // hashes += 1; if !self.ctx.superblock_index.insert_exact( idx, Box::from(superblock.clone()), frame, ) { return Err(std::io::Error::other(SSError::BadSuperblockInsert( frame, idx, ))); } } (State::WaitForSuperblockSeq, SSToken::SuperblockSeq) => { let arr_len = r::read_array_len(self.reader).map_err(std::io::Error::other)? as usize; let last_state_valid = self.ctx.last_superseq.len() >= arr_len && self.ctx.last_state.len() >= self.state_size; let block_byte_size = self.ctx.block_size as usize; let superblock_byte_size = self.ctx.superblock_size as usize * block_byte_size; let mut superseq = vec![0; arr_len]; self.ctx.last_state.resize(self.state_size, 0); let mut skipped_superblocks = 0; let mut skipped_blocks = 0; for (superblock_i, superseq_sblk) in superseq.iter_mut().enumerate() { let superblock_idx = r::read_int(self.reader).map_err(std::io::Error::other)?; *superseq_sblk = superblock_idx; if last_state_valid && self.ctx.last_superseq[superblock_i] == superblock_idx { // no need to copy bytes skipped_superblocks += 1; continue; } let superblock_data = self.ctx.superblock_index.get(superblock_idx); for (block_i, block_id) in superblock_data.iter().copied().enumerate() { if last_state_valid && self .ctx .superblock_index .get(self.ctx.last_superseq[superblock_i])[block_i] == block_id { // no need to copy bytes skipped_blocks += 1; continue; } let block_start = (superblock_i * superblock_byte_size + block_i * block_byte_size) .min(self.state_size); let block_end = (block_start + block_byte_size).min(self.state_size); let block_bytes = self.ctx.block_index.get(block_id); if block_end <= block_start { // This can happen in the last superblock if it was padded with extra blocks break; } self.ctx.last_state[block_start..block_end] .copy_from_slice(&block_bytes[0..(block_end - block_start)]); } } clock::count(Counter::DecSkippedSuperblocks, skipped_superblocks); clock::count(Counter::DecSkippedBlocks, skipped_blocks); self.ctx.last_superseq = superseq; state = State::Finished; self.finished = true; break; } (s, tok) => return Err(std::io::Error::other(SSError::ParseError(s, tok))), } } assert_eq!(state, State::Finished); drop(stopwatch); self.readout(outbuf) } } pub(crate) struct Encoder<'w, 'c, W: std::io::Write> { writer: &'w mut W, ctx: &'c mut Ctx, } /* Does not include the size of the str,arr,map,ext contents */ fn rmp_size(m: rmp::Marker) -> usize { #[allow( clippy::enum_glob_use, reason = "If any new variants are added, the match will cease to be exhaustive" )] use rmp::Marker::*; match m { FixPos(_) | FixNeg(_) | Null | Reserved | False | True | FixMap(_) | FixArray(_) | FixStr(_) => 1, U8 | I8 | Bin8 | Str8 | FixExt1 | FixExt2 | FixExt4 | FixExt8 | FixExt16 => 2, U16 | I16 | Bin16 | Ext8 | Str16 | Array16 | Map16 => 3, Ext16 => 4, U32 | I32 | F32 | Bin32 | Str32 | Array32 | Map32 => 5, Ext32 => 6, U64 | I64 | F64 => 9, } } impl<'w, 'c, W: std::io::Write> Encoder<'w, 'c, W> { pub(crate) fn new(writer: &'w mut W, ctx: &'c mut Ctx) -> Self { Self { writer, ctx } } #[allow(clippy::too_many_lines)] pub fn encode_checkpoint(mut self, checkpoint: &[u8], frame: u64) -> std::io::Result { use rmp::encode as r; let stopwatch = clock::time(Timer::EncodeStatestream); clock::count(Counter::EncTotalKBsIn, (checkpoint.len() / 1024) as u64); let mut bytes_out = 0; bytes_out += rmp_size(r::write_uint( &mut self.writer, u64::from(u8::from(SSToken::Start)), )?); bytes_out += rmp_size(r::write_uint(&mut self.writer, frame)?); let block_size = self.ctx.block_size as usize; let mut padded_block = vec![0; block_size]; let superblock_size = self.ctx.superblock_size as usize; let superblock_size_bytes = block_size * superblock_size; let superblock_count = ((checkpoint.len() - 1) / superblock_size_bytes) + 1; clock::count(Counter::EncTotalSuperblocks, superblock_count as u64); clock::count( Counter::EncTotalBlocks, (((checkpoint.len() - 1) / block_size) + 1) as u64, ); let mut reused_blocks = 0; let mut reused_superblocks = 0; let mut hashes = 0; let mut skipped_blocks = 0; let mut memcmps = 0; self.ctx .last_superseq .resize(superblock_count.max(self.ctx.last_superseq.len()), 0); let mut superblock_contents = vec![0_u32; superblock_size]; let can_compare_saves = if self.ctx.last_state.len() < checkpoint.len() { self.ctx.last_state.clear(); self.ctx.last_state.extend_from_slice(checkpoint); false } else { self.ctx.last_state.truncate(checkpoint.len()); self.ctx.use_encode_state_comparisons }; for (superblock_i, (superblock_bytes, last_state_superblock_bytes)) in (checkpoint .chunks(superblock_size_bytes) .zip(self.ctx.last_state.chunks(superblock_size_bytes))) .enumerate() { /* maybe: skip superblocks */ if superblock_bytes.len() < superblock_size_bytes { let block_count = (superblock_bytes.len() - 1) / block_size + 1; if block_count + 1 < superblock_size { superblock_contents[(block_count + 1)..].fill(0); } } for (block_i, (block_bytes, last_state_block_bytes)) in (superblock_bytes .chunks(block_size) .zip(last_state_superblock_bytes.chunks(block_size))) .enumerate() { memcmps += u64::from(can_compare_saves); let found_block = if can_compare_saves && block_bytes[..] == last_state_block_bytes[..block_bytes.len()] { skipped_blocks += 1; blockindex::Insertion { index: self .ctx .superblock_index .get(self.ctx.last_superseq[superblock_i])[block_i], is_new: false, } } else if block_bytes.len() < block_size { padded_block[block_bytes.len()..].fill(0); padded_block[..block_bytes.len()].copy_from_slice(block_bytes); hashes += 1; self.ctx.block_index.insert(&padded_block, frame) } else { hashes += 1; self.ctx.block_index.insert(block_bytes, frame) }; superblock_contents[block_i] = found_block.index; if found_block.is_new { let block_out_bytes = self.ctx.block_index.get(found_block.index); bytes_out += rmp_size(r::write_uint( self.writer, u64::from(u8::from(SSToken::NewBlock)), )?); bytes_out += rmp_size(r::write_uint(self.writer, u64::from(found_block.index))?); bytes_out += rmp_size(r::write_bin_len(self.writer, self.ctx.block_size)?); self.writer.write_all(block_out_bytes)?; bytes_out += block_out_bytes.len(); } else { reused_blocks += 1; } } hashes += 1; let found_superblock = self .ctx .superblock_index .insert(&superblock_contents, frame); self.ctx.last_superseq[superblock_i] = found_superblock.index; if found_superblock.is_new { bytes_out += rmp_size(r::write_uint( self.writer, u64::from(u8::from(SSToken::NewSuperblock)), )?); bytes_out += rmp_size(r::write_uint( self.writer, u64::from(found_superblock.index), )?); bytes_out += rmp_size(r::write_array_len(self.writer, self.ctx.superblock_size)?); for blkid in &superblock_contents { bytes_out += rmp_size(r::write_uint(self.writer, u64::from(*blkid))?); } } else { reused_superblocks += 1; } } clock::count(Counter::EncReusedBlocks, reused_blocks); clock::count(Counter::EncReusedSuperblocks, reused_superblocks); clock::count(Counter::EncSkippedBlocks, skipped_blocks); clock::count(Counter::EncMemCmps, memcmps); clock::count(Counter::EncHashes, hashes); self.ctx.last_superseq.truncate(superblock_count); bytes_out += rmp_size(r::write_uint( self.writer, u64::from(u8::from(SSToken::SuperblockSeq)), )?); bytes_out += rmp_size(r::write_array_len( self.writer, u32::try_from(superblock_count) .map_err(|e| std::io::Error::other(crate::ReplayError::CheckpointTooBig(e)))?, )?); for super_id in &self.ctx.last_superseq { bytes_out += rmp_size(r::write_uint(self.writer, u64::from(*super_id))?); } drop(stopwatch); clock::count(Counter::EncTotalKBsOut, (bytes_out / 1024) as u64); u32::try_from(bytes_out) .map_err(|e| std::io::Error::other(crate::ReplayError::CheckpointTooBig(e))) } }