From 7f16427d5041a787ff5ad9bf61cb3bd4e7dad5bb Mon Sep 17 00:00:00 2001 From: "Joseph C. Osborn" Date: Tue, 28 Oct 2025 15:38:00 -0700 Subject: [PATCH] wip --- src/bin/reencode.rs | 44 +++++ src/rply.rs | 382 +++++++++++++++++++++++++++++++++++++++++++- src/statestream.rs | 19 +++ 3 files changed, 437 insertions(+), 8 deletions(-) create mode 100644 src/bin/reencode.rs diff --git a/src/bin/reencode.rs b/src/bin/reencode.rs new file mode 100644 index 0000000..bbef427 --- /dev/null +++ b/src/bin/reencode.rs @@ -0,0 +1,44 @@ +use rply_codec::{Frame, decode, encode}; + +fn main() { + let args: Vec<_> = std::env::args().collect(); + let file = + std::fs::File::open(args.get(1).unwrap_or(&"examples/bobl.replay".to_string())).unwrap(); + let outfile = std::fs::File::open( + args.get(2) + .unwrap_or(&"examples/bobl_smallblocks.replay".to_string()), + ) + .unwrap(); + let mut file = std::io::BufReader::new(file); + let mut outfile = std::io::BufWriter::new(outfile); + let mut rply = decode(&mut file).unwrap(); + let header = &rply.header; + println!("{header:?}"); + let mut header_out = header.clone(); + header_out.set_block_size(64); + let mut out = encode(header_out, &rply.initial_state, &mut outfile).unwrap(); + let mut frame = Frame::default(); + while let Ok(()) = rply + .read_frame(&mut frame) + .inspect_err(|e| println!("Err: {e}")) + { + println!( + " {}{:08} {}", + if frame.checkpoint_bytes.is_empty() { + " " + } else { + "*" + }, + rply.frame_number, + frame.inputs(), + ); + out.write_frame(&frame).unwrap(); + if Some(rply.frame_number) == rply.header.frame_count() { + println!("Done!"); + break; + } + } + out.finish().unwrap(); + assert_eq!(out.frame_number, rply.frame_number); + assert_eq!(out.header.frame_count(), rply.header.frame_count()); +} diff --git a/src/rply.rs b/src/rply.rs index 5124831..cff46ca 100644 --- a/src/rply.rs +++ b/src/rply.rs @@ -1,3 +1,5 @@ +use std::io::Write; + use crate::{InvalidDeterminant, statestream}; use thiserror::Error; @@ -19,7 +21,7 @@ use thiserror::Error; // HeaderLen = 40, // } // const HEADER_V0V1_LEN_BYTES: usize = HeaderV0V1Part::HeaderLen as usize; -// const HEADER_LEN_BYTES: usize = HeaderV2Part::HeaderLen as usize; +const HEADERV2_LEN_BYTES: usize = 40; // const VERSION: u32 = 2; const MAGIC: u32 = 0x4253_5632; @@ -43,6 +45,16 @@ impl From for FrameToken { } } } +impl From for u8 { + fn from(value: FrameToken) -> Self { + match value { + FrameToken::Invalid => 0, + FrameToken::Regular => b'f', + FrameToken::Checkpoint => b'c', + FrameToken::Checkpoint2 => b'C', + } + } +} #[repr(u8)] #[non_exhaustive] @@ -66,6 +78,16 @@ impl TryFrom for Compression { } } +impl From for u8 { + fn from(value: Compression) -> Self { + match value { + Compression::None => 0, + Compression::Zlib => 1, + Compression::Zstd => 2, + } + } +} + #[repr(u8)] #[non_exhaustive] #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -86,7 +108,16 @@ impl TryFrom for Encoding { } } -#[derive(Debug)] +impl From for u8 { + fn from(value: Encoding) -> Self { + match value { + Encoding::Raw => 0, + Encoding::Statestream => 1, + } + } +} + +#[derive(Debug, Clone)] pub struct HeaderBase { pub version: u32, pub content_crc: u32, @@ -94,7 +125,7 @@ pub struct HeaderBase { pub identifier: u64, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct HeaderV2 { pub base: HeaderBase, pub frame_count: u32, @@ -105,7 +136,7 @@ pub struct HeaderV2 { pub checkpoint_compression: Compression, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Header { V0V1(HeaderBase), V2(HeaderV2), @@ -127,6 +158,12 @@ pub enum ReplayError { NoCoreRead(), #[error("Checkpoint too big {0}")] CheckpointTooBig(std::num::TryFromIntError), + #[error("Frame too long {0}")] + FrameTooLong(std::num::TryFromIntError), + #[error("Frame has too many key events {0}")] + TooManyKeyEvents(std::num::TryFromIntError), + #[error("Frame has too many input events {0}")] + TooManyInputEvents(std::num::TryFromIntError), #[error("Invalid frame token {0}")] BadFrameToken(u8), } @@ -381,13 +418,272 @@ pub fn decode(rply: &mut R) -> Result> ReplayDecoder::new(rply) } +pub struct ReplayEncoder<'a, W: std::io::Write + std::io::Seek> { + rply: &'a mut W, + pub header: Header, + pub initial_state: Vec, + pub frame_number: u64, + last_pos: u64, + ss_state: statestream::Ctx, + finished: bool, +} + +impl<'w, W: std::io::Write + std::io::Seek> ReplayEncoder<'w, W> { + /// Creates a [`ReplayEncoder`] for the given writable and seekable stream. + /// + /// # Errors + /// [`ReplayError::IO`]: Some issue with the write stream, e.g. unexpected end + /// [`ReplayError::Version`]: Version identifier not supported by writer + /// [`ReplayError::Compression`]: Unsupported compression scheme for checkpoints + pub fn new<'s>( + header: Header, + initial_state: &'s [u8], + rply: &'w mut W, + ) -> Result> { + if header.version() != 2 { + return Err(ReplayError::Version(header.version())); + } + let ss_state = statestream::Ctx::new(header.block_size(), header.superblock_size()); + let mut replay = ReplayEncoder { + rply, + header, + initial_state: vec![], + frame_number: 0, + last_pos: 0, + ss_state, + finished: false, + }; + replay.write_header()?; + if !initial_state.is_empty() { + replay.encode_initial_checkpoint(initial_state)?; + } + replay.last_pos = replay.rply.stream_position()?; + Ok(replay) + } + fn write_header(&mut self) -> Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + self.header + .set_frame_count(u32::try_from(self.frame_number).unwrap_or_default()); + let old_pos = self.rply.stream_position()?; + self.rply.seek(std::io::SeekFrom::Start(0))?; + self.rply.write_u32::(MAGIC)?; + self.rply.write_u32::(2)?; + self.rply + .write_u32::(self.header.content_crc())?; + // state size + self.rply + .write_u32::(self.header.initial_state_size())?; + self.rply + .write_u64::(self.header.identifier())?; + self.rply + .write_u32::(self.header.block_size())?; + self.rply + .write_u32::(self.header.superblock_size())?; + let cp_interval = u32::from(self.header.checkpoint_commit_interval()); + let cp_threshold = u32::from(self.header.checkpoint_commit_threshold()); + let cp_compression = u32::from(u8::from(self.header.checkpoint_compression())); + self.rply.write_u32::( + (cp_interval << 24) | (cp_threshold << 16) | (cp_compression << 8), + )?; + self.rply.seek(std::io::SeekFrom::Start(old_pos))?; + Ok(()) + } + fn encode_checkpoint(&mut self, checkpoint: &[u8], frame: u64) -> Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + let compression = self.header.checkpoint_compression(); + let encoding = Encoding::Statestream; + self.rply.write_u8(u8::from(compression))?; + self.rply.write_u8(u8::from(encoding))?; + // write unencoded uncompressed size + let full_size = u32::try_from(checkpoint.len()).map_err(ReplayError::CheckpointTooBig)?; + self.rply.write_u32::(full_size)?; + let size_pos = self.rply.stream_position()?; + // can't yet write encoded uncompressed size, just write zeros for now + // write encoded compressed size + self.rply.write_u32::(0)?; + // write encoded compressed bytes + self.rply.write_u32::(0)?; + let (encoded_size, compressed_size) = match (compression, encoding) { + (Compression::None, Encoding::Raw) => { + self.rply.write_all(checkpoint)?; + (full_size, full_size) + } + (Compression::None, Encoding::Statestream) => { + let encoder = statestream::Encoder::new(&mut self.rply, &mut self.ss_state); + let encoded_size = encoder.encode_checkpoint(checkpoint, frame)?; + (encoded_size, encoded_size) + } + (Compression::Zlib, Encoding::Raw) => { + use flate2::write::ZlibEncoder; + let here_pos = self.rply.stream_position()?; + let mut encoder = ZlibEncoder::new(&mut self.rply, flate2::Compression::default()); + let encoded_size = full_size; + encoder.write_all(checkpoint)?; + encoder.finish()?; + let compressed_size = u32::try_from(self.rply.stream_position()? - here_pos) + .map_err(ReplayError::CheckpointTooBig)?; + (encoded_size, compressed_size) + } + (Compression::Zlib, Encoding::Statestream) => { + use flate2::write::ZlibEncoder; + let here_pos = self.rply.stream_position()?; + let mut compressor = + ZlibEncoder::new(&mut self.rply, flate2::Compression::default()); + let encoder = statestream::Encoder::new(&mut compressor, &mut self.ss_state); + let encoded_size = encoder.encode_checkpoint(checkpoint, frame)?; + compressor.finish()?; + let compressed_size = u32::try_from(self.rply.stream_position()? - here_pos) + .map_err(ReplayError::CheckpointTooBig)?; + (encoded_size, compressed_size) + } + (Compression::Zstd, Encoding::Raw) => { + let here_pos = self.rply.stream_position()?; + let mut encoder = zstd::Encoder::new(&mut self.rply, 16)?; + encoder.write_all(checkpoint)?; + encoder.finish()?; + let encoded_size = full_size; + let compressed_size = u32::try_from(self.rply.stream_position()? - here_pos) + .map_err(ReplayError::CheckpointTooBig)?; + (encoded_size, compressed_size) + } + (Compression::Zstd, Encoding::Statestream) => { + let here_pos = self.rply.stream_position()?; + let mut compressor = zstd::Encoder::new(&mut self.rply, 16)?; + let encoder = statestream::Encoder::new(&mut compressor, &mut self.ss_state); + let encoded_size = encoder.encode_checkpoint(checkpoint, frame)?; + compressor.finish()?; + let compressed_size = u32::try_from(self.rply.stream_position()? - here_pos) + .map_err(ReplayError::CheckpointTooBig)?; + (encoded_size, compressed_size) + } + }; + let end_pos = self.rply.stream_position()?; + self.rply.seek(std::io::SeekFrom::Start(size_pos))?; + // write encoded compressed size + self.rply.write_u32::(encoded_size)?; + // write encoded compressed bytes + self.rply.write_u32::(compressed_size)?; + self.rply.seek(std::io::SeekFrom::Start(end_pos))?; + Ok(()) + } + fn encode_initial_checkpoint(&mut self, checkpoint: &[u8]) -> Result<()> { + let initial = std::mem::take(&mut self.initial_state); + let old_pos = self.rply.stream_position()?; + self.rply + .seek(std::io::SeekFrom::Start(HEADERV2_LEN_BYTES as u64))?; + self.encode_checkpoint(checkpoint, 0)?; + self.header.set_initial_state_size(initial.len() as u32); + self.initial_state = initial; + // Have to rewrite header to account for initial state size + self.write_header()?; + self.last_pos = self.rply.stream_position()?; + self.rply.seek(std::io::SeekFrom::Start(old_pos))?; + Ok(()) + } + /// Writes a single frame at the current encoder position. + pub fn write_frame(&mut self, frame: &Frame) -> Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + let start_pos = self.rply.stream_position()?; + self.rply.write_u32::( + u32::try_from(start_pos - self.last_pos).map_err(ReplayError::FrameTooLong)?, + )?; + self.rply.write_u8( + u8::try_from(frame.key_events.len()).map_err(ReplayError::TooManyKeyEvents)?, + )?; + for evt in &frame.key_events { + self.rply.write_u8(evt.down)?; + self.rply.write_u8(0)?; // padding + self.rply.write_u16::(evt.modf)?; + self.rply.write_u32::(evt.code)?; + self.rply.write_u32::(evt.chr)?; + } + self.rply.write_u16::( + u16::try_from(frame.input_events.len()).map_err(ReplayError::TooManyInputEvents)?, + )?; + for evt in &frame.input_events { + self.rply.write_u8(evt.port)?; + self.rply.write_u8(evt.device)?; + self.rply.write_u8(evt.idx)?; + self.rply.write_u8(0)?; // padding + self.rply.write_u16::(evt.id)?; + self.rply.write_i16::(evt.val)?; + } + if frame.checkpoint_bytes.is_empty() { + self.rply.write_u8(u8::from(FrameToken::Regular))?; + } else { + self.rply.write_u8(u8::from(FrameToken::Checkpoint2))?; + self.encode_checkpoint(&frame.checkpoint_bytes, self.frame_number)?; + } + self.frame_number += 1; + self.last_pos = start_pos; + Ok(()) + } + /// Finishes the encoding, writing the header in the process + pub fn finish(&mut self) -> Result<()> { + if self.finished { + return Ok(()); + } + self.write_header()?; + self.finished = true; + Ok(()) + } +} + +impl Drop for ReplayEncoder<'_, W> { + fn drop(&mut self) { + self.finish().unwrap(); + } +} + +/// Creates a [`ReplayEncoder`] for the given writable & seekable stream. +/// +/// # Errors +/// See [`ReplayEncoder::new`]. +pub fn encode<'w, 's, W: std::io::Write + std::io::Seek>( + header: Header, + initial_state: &'s [u8], + rply: &'w mut W, +) -> Result> { + ReplayEncoder::new(header, initial_state, rply) +} + impl Header { + fn base(&self) -> &HeaderBase { + match self { + Header::V0V1(header_base) => header_base, + Header::V2(header_v2) => &header_v2.base, + } + } + fn base_mut(&mut self) -> &mut HeaderBase { + match self { + Header::V0V1(header_base) => header_base, + Header::V2(header_v2) => &mut header_v2.base, + } + } #[must_use] pub fn version(&self) -> u32 { - match self { - Header::V0V1(header_base) => header_base.version, - Header::V2(header_v2) => header_v2.base.version, - } + self.base().version + } + #[must_use] + pub fn content_crc(&self) -> u32 { + self.base().content_crc + } + pub fn set_content_crc(&mut self, crc: u32) { + self.base_mut().content_crc = crc; + } + #[must_use] + pub fn identifier(&self) -> u64 { + self.base().identifier + } + pub fn set_identifier(&mut self, id: u64) { + self.base_mut().identifier = id; + } + #[must_use] + pub fn initial_state_size(&self) -> u32 { + self.base().initial_state_size + } + pub fn set_initial_state_size(&mut self, sz: u32) { + self.base_mut().initial_state_size = sz; } #[must_use] pub fn frame_count(&self) -> Option { @@ -396,6 +692,76 @@ impl Header { Header::V2(header_v2) => Some(u64::from(header_v2.frame_count)), } } + pub fn set_frame_count(&mut self, frames: u32) { + self.upgrade().frame_count = frames; + } + pub fn upgrade(&mut self) -> &mut HeaderV2 { + if let Header::V0V1(base) = self { + *self = Header::V2(HeaderV2 { + base: base.clone(), + frame_count: 0, + block_size: 0, + superblock_size: 0, + checkpoint_commit_interval: 0, + checkpoint_commit_threshold: 0, + checkpoint_compression: Compression::None, + }); + } + let Header::V2(v2) = self else { unreachable!() }; + v2 + } + #[must_use] + pub fn block_size(&self) -> u32 { + match self { + Header::V0V1(_) => 0, + Header::V2(header_v2) => header_v2.block_size, + } + } + pub fn set_block_size(&mut self, sz: u32) { + let v2 = self.upgrade(); + v2.block_size = sz; + } + #[must_use] + pub fn superblock_size(&self) -> u32 { + match self { + Header::V0V1(_) => 0, + Header::V2(header_v2) => header_v2.superblock_size, + } + } + pub fn set_superblock_size(&mut self, sz: u32) { + let v2 = self.upgrade(); + v2.superblock_size = sz; + } + #[must_use] + pub fn checkpoint_commit_interval(&self) -> u8 { + match self { + Header::V0V1(_) => 0, + Header::V2(header_v2) => header_v2.checkpoint_commit_interval, + } + } + #[must_use] + pub fn checkpoint_commit_threshold(&self) -> u8 { + match self { + Header::V0V1(_) => 0, + Header::V2(header_v2) => header_v2.checkpoint_commit_threshold, + } + } + pub fn set_checkpoint_commit_settings(&mut self, interval: u8, threshold: u8) { + let v2 = self.upgrade(); + v2.checkpoint_commit_interval = interval; + v2.checkpoint_commit_threshold = threshold; + } + #[must_use] + pub fn checkpoint_compression(&self) -> Compression { + match self { + Header::V0V1(_) => Compression::None, + Header::V2(header_v2) => header_v2.checkpoint_compression, + } + } + pub fn set_checkpoint_compression(&mut self, compression: Compression) { + let v2 = self.upgrade(); + v2.checkpoint_compression = compression; + } } #[derive(Debug, Default)] pub struct KeyData { diff --git a/src/statestream.rs b/src/statestream.rs index e80facf..6aa2adf 100644 --- a/src/statestream.rs +++ b/src/statestream.rs @@ -203,3 +203,22 @@ impl std::io::Read for Decoder<'_, '_, R> { self.readout(outbuf) } } + +pub(crate) struct Encoder<'w, 'c, W: std::io::Write> { + writer: &'w mut W, + ctx: &'c mut Ctx, +} + +impl<'w, 'c, W: std::io::Write> Encoder<'w, 'c, W> { + pub(crate) fn new(writer: &'w mut W, ctx: &'c mut Ctx) -> Self { + Self { writer, ctx } + } + pub fn encode_checkpoint(mut self, checkpoint: &[u8], frame: u64) -> std::io::Result { + use rmp::encode as r; + r::write_uint(&mut self.writer, u64::from(u8::from(SSToken::Start)))?; + r::write_uint(&mut self.writer, frame)?; + todo!(); + + Ok(0) + } +}