From 6f3111d06a3c5336c9d56602fac9d29899a14853 Mon Sep 17 00:00:00 2001 From: Ben Pfaff Date: Mon, 21 Jul 2025 18:44:08 -0700 Subject: [PATCH] cleanup --- rust/pspp/src/sys/write.rs | 291 ++++++++++++++++++++----------------- 1 file changed, 158 insertions(+), 133 deletions(-) diff --git a/rust/pspp/src/sys/write.rs b/rust/pspp/src/sys/write.rs index 369bea3f25..cfb25a5897 100644 --- a/rust/pspp/src/sys/write.rs +++ b/rust/pspp/src/sys/write.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] -use core::f64; use std::{ borrow::Cow, collections::HashMap, @@ -12,6 +10,7 @@ use std::{ use binrw::{BinWrite, Endian, Error as BinError}; use chrono::Local; +use either::Either; use encoding_rs::Encoding; use flate2::write::ZlibEncoder; use itertools::zip_eq; @@ -94,13 +93,17 @@ impl WriteOptions { self, dictionary: &Dictionary, path: impl AsRef, - ) -> Result { + ) -> Result>, BinError> { self.write_writer(dictionary, BufWriter::new(File::create(path)?)) } /// Writes `dictionary` to `writer` in system file format. Returns a /// [Writer] that can be used for writing cases to the new file. 
- pub fn write_writer(self, dictionary: &Dictionary, mut writer: W) -> Result + pub fn write_writer( + self, + dictionary: &Dictionary, + mut writer: W, + ) -> Result, BinError> where W: Write + Seek + 'static, { @@ -647,14 +650,6 @@ impl<'a> Padded<'a> { padding: Pad::new(length - min, pad), } } - - pub fn to_multiple(bytes: &'a [u8], multiple: usize, pad: u8) -> Self { - let length = bytes.len().next_multiple_of(multiple); - Self { - padding: Pad::new(length - bytes.len(), pad), - bytes, - } - } } pub struct Pad { @@ -744,7 +739,6 @@ impl Iterator for SegmentWidths { enum CaseVar { Numeric, String { - width: usize, encoding: SmallVec<[StringSegment; 1]>, }, } @@ -770,24 +764,10 @@ impl CaseVar { encoding.last_mut().unwrap().padding_bytes += padding_bytes; } } - CaseVar::String { - width: w as usize, - encoding, - } + CaseVar::String { encoding } } } } - - fn bytes(&self) -> usize { - match self { - CaseVar::Numeric => 8, - CaseVar::String { width: _, encoding } => encoding - .iter() - .map(|segment| segment.data_bytes + segment.padding_bytes) - .sum(), - } - } - fn n_segments(&self) -> usize { match self { CaseVar::Numeric => 1, @@ -797,26 +777,42 @@ impl CaseVar { } /// System file writer. 
-pub struct Writer { +pub struct Writer +where + W: Write + Seek, +{ compression: Option, case_vars: Vec, - inner: WriterInner, -} - -pub struct WriterInner { opcodes: Vec, data: Vec, - inner: Box, + inner: Option>>, + n_cases: u64, } -trait WriteSeek: Write + Seek {} -impl WriteSeek for T where T: Write + Seek {} +pub struct WriterInner<'a, W: Write> { + case_vars: &'a [CaseVar], + opcodes: &'a mut Vec, + data: &'a mut Vec, + inner: &'a mut W, +} -impl WriterInner { - fn finish(mut self) -> Result<(), BinError> { - self.flush_compressed() +impl<'a, W> WriterInner<'a, W> +where + W: Write + Seek, +{ + fn new( + case_vars: &'a [CaseVar], + opcodes: &'a mut Vec, + data: &'a mut Vec, + inner: &'a mut W, + ) -> Self { + Self { + case_vars, + opcodes, + data, + inner, + } } - fn flush_compressed(&mut self) -> Result<(), BinError> { if !self.opcodes.is_empty() { self.opcodes.resize(8, 0); @@ -834,103 +830,52 @@ impl WriterInner { self.opcodes.push(opcode); Ok(()) } -} - -impl Writer { - fn new(options: WriteOptions, case_vars: Vec, inner: W) -> Result - where - W: Write + Seek + 'static, - { - Ok(Self { - compression: options.compression, - case_vars, - inner: WriterInner { - opcodes: Vec::with_capacity(8), - data: Vec::with_capacity(64), - inner: match options.compression { - Some(Compression::ZLib) => Box::new(ZlibWriter::new(inner)?), - _ => Box::new(inner), - }, - }, - }) - } - - /// Finishes writing the file, flushing buffers and updating headers to - /// match the final case counts. - pub fn finish(mut self) -> Result<(), BinError> { - self.try_finish() - } - - /// Tries to finish writing the file, flushing buffers and updating headers - /// to match the final case counts. - /// - /// # Panic - /// - /// Attempts to write more cases after calling this function may result in a - /// panic. - pub fn try_finish(&mut self) -> Result<(), BinError> { - self.inner.flush_compressed() - } - /// Writes `case` to the system file. 
- pub fn write_case<'a>( - &mut self, - case: impl IntoIterator, - ) -> Result<(), BinError> { - match self.compression { - Some(_) => self.write_case_compressed(case.into_iter()), - None => self.write_case_uncompressed(case.into_iter()), - } - } - fn write_case_uncompressed<'a>( + fn write_case_uncompressed<'c>( &mut self, - case: impl Iterator, + case: impl Iterator, ) -> Result<(), BinError> { - for (var, datum) in zip_eq(&self.case_vars, case) { + for (var, datum) in zip_eq(self.case_vars, case) { match var { CaseVar::Numeric => datum .as_number() .unwrap() .unwrap_or(f64::MIN) - .write_le(&mut self.inner.inner)?, - CaseVar::String { width: _, encoding } => { + .write_le(&mut self.inner)?, + CaseVar::String { encoding } => { let mut s = datum.as_string().unwrap().as_bytes(); for segment in encoding { let data; (data, s) = s.split_at(segment.data_bytes); - (data, Pad::new(segment.padding_bytes, 0)) - .write_le(&mut self.inner.inner)?; + (data, Pad::new(segment.padding_bytes, 0)).write_le(&mut self.inner)?; } } } } Ok(()) } - fn write_case_compressed<'a>( + fn write_case_compressed<'c>( &mut self, - case: impl Iterator, + case: impl Iterator, ) -> Result<(), BinError> { - for (var, datum) in zip_eq(&self.case_vars, case) { + for (var, datum) in zip_eq(self.case_vars, case) { match var { CaseVar::Numeric => match datum.as_number().unwrap() { - None => self.inner.put_opcode(255)?, + None => self.put_opcode(255)?, Some(number) => { if number >= 1.0 - BIAS && number <= 251.0 - BIAS && number == number.trunc() { - self.inner.put_opcode((number + BIAS) as u8)? + self.put_opcode((number + BIAS) as u8)? 
} else { - self.inner.put_opcode(253)?; - - number - .write_le(&mut Cursor::new(&mut self.inner.data)) - .unwrap(); + self.put_opcode(253)?; + self.data.extend_from_slice(&number.to_le_bytes()); } } }, - CaseVar::String { width: _, encoding } => { + CaseVar::String { encoding } => { let mut s = datum.as_string().unwrap().as_bytes(); for segment in encoding { let data; @@ -939,23 +884,23 @@ impl Writer { let (chunks, remainder) = data.as_chunks::<8>(); for chunk in chunks { if chunk == b" " { - self.inner.put_opcode(254)?; + self.put_opcode(254)?; } else { - self.inner.put_opcode(253)?; - self.inner.data.extend_from_slice(chunk); + self.put_opcode(253)?; + self.data.extend_from_slice(chunk); } } if !remainder.is_empty() { if remainder.iter().all(|c| *c == b' ') { - self.inner.put_opcode(254)?; + self.put_opcode(254)?; } else { - self.inner.put_opcode(253)?; - self.inner.data.extend_from_slice(remainder); - self.inner.data.extend(repeat_n(0, 8 - remainder.len())); + self.put_opcode(253)?; + self.data.extend_from_slice(remainder); + self.data.extend(repeat_n(0, 8 - remainder.len())); } } for _ in 0..segment.padding_bytes / 8 { - self.inner.put_opcode(254)?; + self.put_opcode(254)?; } } } @@ -965,16 +910,108 @@ impl Writer { } } -impl Drop for Writer { +impl Writer +where + W: Write + Seek, +{ + fn new(options: WriteOptions, case_vars: Vec, inner: W) -> Result { + Ok(Self { + compression: options.compression, + case_vars, + opcodes: Vec::with_capacity(8), + data: Vec::with_capacity(64), + n_cases: 0, + inner: match options.compression { + Some(Compression::ZLib) => Some(Either::Right(ZlibWriter::new(inner)?)), + _ => Some(Either::Left(inner)), + }, + }) + } + + /// Finishes writing the file, flushing buffers and updating headers to + /// match the final case counts. + pub fn finish(mut self) -> Result<(), BinError> { + self.try_finish() + } + + /// Tries to finish writing the file, flushing buffers and updating headers + /// to match the final case counts. 
+ /// + /// # Panic + /// + /// Attempts to write more cases after calling this function will panic. + pub fn try_finish(&mut self) -> Result<(), BinError> { + let Some(inner) = self.inner.take() else { + return Ok(()); + }; + + let mut inner = match inner { + Either::Left(mut inner) => { + WriterInner::new( + &self.case_vars, + &mut self.opcodes, + &mut self.data, + &mut inner, + ) + .flush_compressed()?; + inner + } + Either::Right(mut zlib_writer) => { + WriterInner::new( + &self.case_vars, + &mut self.opcodes, + &mut self.data, + &mut zlib_writer, + ) + .flush_compressed()?; + zlib_writer.finish()? + } + }; + if let Ok(n_cases) = u32::try_from(self.n_cases) { + if inner.seek(SeekFrom::Start(80)).is_ok() { + let _ = inner.write_all(&n_cases.to_le_bytes()); + } + } + Ok(()) + } + + /// Writes `case` to the system file.
+ } + } + self.n_cases += 1; + Ok(()) + } +} + +impl Drop for Writer +where + W: Write + Seek, +{ fn drop(&mut self) { let _ = self.try_finish(); } } -struct Block { - uncompressed_size: u64, - compressed_size: u64, -} struct ZlibWriter where W: Write + Seek, @@ -1033,7 +1070,7 @@ where Ok(()) } - fn try_finish(&mut self) -> Result<(), BinError> { + fn finish(mut self) -> Result { self.flush_block()?; let ztrailer_offset = self.inner.stream_position()?; self.trailer.write_le(&mut self.inner)?; @@ -1043,20 +1080,8 @@ where ztrailer_len: self.trailer.len() as u64, }; self.inner.seek(SeekFrom::Start(header.zheader_offset))?; - header.write_le(&mut self.inner) - } - - fn finish(mut self) -> Result<(), BinError> { - self.try_finish() - } -} - -impl Drop for ZlibWriter -where - W: Write + Seek, -{ - fn drop(&mut self) { - let _ = self.try_finish(); + header.write_le(&mut self.inner)?; + Ok(self.inner) } } -- 2.30.2