cleanup
authorBen Pfaff <blp@cs.stanford.edu>
Tue, 22 Jul 2025 01:44:08 +0000 (18:44 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Tue, 22 Jul 2025 01:44:08 +0000 (18:44 -0700)
rust/pspp/src/sys/write.rs

index 369bea3f25fc8a8fc362992e1e1233d96d239e02..cfb25a58973ae3cae8f542b17d168a60ddf2b2e0 100644 (file)
@@ -1,5 +1,3 @@
-#![allow(dead_code)]
-use core::f64;
 use std::{
     borrow::Cow,
     collections::HashMap,
@@ -12,6 +10,7 @@ use std::{
 
 use binrw::{BinWrite, Endian, Error as BinError};
 use chrono::Local;
+use either::Either;
 use encoding_rs::Encoding;
 use flate2::write::ZlibEncoder;
 use itertools::zip_eq;
@@ -94,13 +93,17 @@ impl WriteOptions {
         self,
         dictionary: &Dictionary,
         path: impl AsRef<Path>,
-    ) -> Result<Writer, BinError> {
+    ) -> Result<Writer<BufWriter<File>>, BinError> {
         self.write_writer(dictionary, BufWriter::new(File::create(path)?))
     }
 
     /// Writes `dictionary` to `writer` in system file format.  Returns a
     /// [Writer] that can be used for writing cases to the new file.
-    pub fn write_writer<W>(self, dictionary: &Dictionary, mut writer: W) -> Result<Writer, BinError>
+    pub fn write_writer<W>(
+        self,
+        dictionary: &Dictionary,
+        mut writer: W,
+    ) -> Result<Writer<W>, BinError>
     where
         W: Write + Seek + 'static,
     {
@@ -647,14 +650,6 @@ impl<'a> Padded<'a> {
             padding: Pad::new(length - min, pad),
         }
     }
-
-    pub fn to_multiple(bytes: &'a [u8], multiple: usize, pad: u8) -> Self {
-        let length = bytes.len().next_multiple_of(multiple);
-        Self {
-            padding: Pad::new(length - bytes.len(), pad),
-            bytes,
-        }
-    }
 }
 
 pub struct Pad {
@@ -744,7 +739,6 @@ impl Iterator for SegmentWidths {
 enum CaseVar {
     Numeric,
     String {
-        width: usize,
         encoding: SmallVec<[StringSegment; 1]>,
     },
 }
@@ -770,24 +764,10 @@ impl CaseVar {
                         encoding.last_mut().unwrap().padding_bytes += padding_bytes;
                     }
                 }
-                CaseVar::String {
-                    width: w as usize,
-                    encoding,
-                }
+                CaseVar::String { encoding }
             }
         }
     }
-
-    fn bytes(&self) -> usize {
-        match self {
-            CaseVar::Numeric => 8,
-            CaseVar::String { width: _, encoding } => encoding
-                .iter()
-                .map(|segment| segment.data_bytes + segment.padding_bytes)
-                .sum(),
-        }
-    }
-
     fn n_segments(&self) -> usize {
         match self {
             CaseVar::Numeric => 1,
@@ -797,26 +777,42 @@ impl CaseVar {
 }
 
 /// System file writer.
-pub struct Writer {
+pub struct Writer<W>
+where
+    W: Write + Seek,
+{
     compression: Option<Compression>,
     case_vars: Vec<CaseVar>,
-    inner: WriterInner,
-}
-
-pub struct WriterInner {
     opcodes: Vec<u8>,
     data: Vec<u8>,
-    inner: Box<dyn WriteSeek>,
+    inner: Option<Either<W, ZlibWriter<W>>>,
+    n_cases: u64,
 }
 
-trait WriteSeek: Write + Seek {}
-impl<T> WriteSeek for T where T: Write + Seek {}
+pub struct WriterInner<'a, W: Write> {
+    case_vars: &'a [CaseVar],
+    opcodes: &'a mut Vec<u8>,
+    data: &'a mut Vec<u8>,
+    inner: &'a mut W,
+}
 
-impl WriterInner {
-    fn finish(mut self) -> Result<(), BinError> {
-        self.flush_compressed()
+impl<'a, W> WriterInner<'a, W>
+where
+    W: Write + Seek,
+{
+    fn new(
+        case_vars: &'a [CaseVar],
+        opcodes: &'a mut Vec<u8>,
+        data: &'a mut Vec<u8>,
+        inner: &'a mut W,
+    ) -> Self {
+        Self {
+            case_vars,
+            opcodes,
+            data,
+            inner,
+        }
     }
-
     fn flush_compressed(&mut self) -> Result<(), BinError> {
         if !self.opcodes.is_empty() {
             self.opcodes.resize(8, 0);
@@ -834,103 +830,52 @@ impl WriterInner {
         self.opcodes.push(opcode);
         Ok(())
     }
-}
-
-impl Writer {
-    fn new<W>(options: WriteOptions, case_vars: Vec<CaseVar>, inner: W) -> Result<Self, BinError>
-    where
-        W: Write + Seek + 'static,
-    {
-        Ok(Self {
-            compression: options.compression,
-            case_vars,
-            inner: WriterInner {
-                opcodes: Vec::with_capacity(8),
-                data: Vec::with_capacity(64),
-                inner: match options.compression {
-                    Some(Compression::ZLib) => Box::new(ZlibWriter::new(inner)?),
-                    _ => Box::new(inner),
-                },
-            },
-        })
-    }
-
-    /// Finishes writing the file, flushing buffers and updating headers to
-    /// match the final case counts.
-    pub fn finish(mut self) -> Result<(), BinError> {
-        self.try_finish()
-    }
-
-    /// Tries to finish writing the file, flushing buffers and updating headers
-    /// to match the final case counts.
-    ///
-    /// # Panic
-    ///
-    /// Attempts to write more cases after calling this function may result in a
-    /// panic.
-    pub fn try_finish(&mut self) -> Result<(), BinError> {
-        self.inner.flush_compressed()
-    }
 
-    /// Writes `case` to the system file.
-    pub fn write_case<'a>(
-        &mut self,
-        case: impl IntoIterator<Item = &'a Datum>,
-    ) -> Result<(), BinError> {
-        match self.compression {
-            Some(_) => self.write_case_compressed(case.into_iter()),
-            None => self.write_case_uncompressed(case.into_iter()),
-        }
-    }
-    fn write_case_uncompressed<'a>(
+    fn write_case_uncompressed<'c>(
         &mut self,
-        case: impl Iterator<Item = &'a Datum>,
+        case: impl Iterator<Item = &'c Datum>,
     ) -> Result<(), BinError> {
-        for (var, datum) in zip_eq(&self.case_vars, case) {
+        for (var, datum) in zip_eq(self.case_vars, case) {
             match var {
                 CaseVar::Numeric => datum
                     .as_number()
                     .unwrap()
                     .unwrap_or(f64::MIN)
-                    .write_le(&mut self.inner.inner)?,
-                CaseVar::String { width: _, encoding } => {
+                    .write_le(&mut self.inner)?,
+                CaseVar::String { encoding } => {
                     let mut s = datum.as_string().unwrap().as_bytes();
                     for segment in encoding {
                         let data;
                         (data, s) = s.split_at(segment.data_bytes);
-                        (data, Pad::new(segment.padding_bytes, 0))
-                            .write_le(&mut self.inner.inner)?;
+                        (data, Pad::new(segment.padding_bytes, 0)).write_le(&mut self.inner)?;
                     }
                 }
             }
         }
         Ok(())
     }
-    fn write_case_compressed<'a>(
+    fn write_case_compressed<'c>(
         &mut self,
-        case: impl Iterator<Item = &'a Datum>,
+        case: impl Iterator<Item = &'c Datum>,
     ) -> Result<(), BinError> {
-        for (var, datum) in zip_eq(&self.case_vars, case) {
+        for (var, datum) in zip_eq(self.case_vars, case) {
             match var {
                 CaseVar::Numeric => match datum.as_number().unwrap() {
-                    None => self.inner.put_opcode(255)?,
+                    None => self.put_opcode(255)?,
                     Some(number) => {
                         if number >= 1.0 - BIAS
                             && number <= 251.0 - BIAS
                             && number == number.trunc()
                         {
-                            self.inner.put_opcode((number + BIAS) as u8)?
+                            self.put_opcode((number + BIAS) as u8)?
                         } else {
-                            self.inner.put_opcode(253)?;
-
-                            number
-                                .write_le(&mut Cursor::new(&mut self.inner.data))
-                                .unwrap();
+                            self.put_opcode(253)?;
+                            self.data.extend_from_slice(&number.to_le_bytes());
                         }
                     }
                 },
 
-                CaseVar::String { width: _, encoding } => {
+                CaseVar::String { encoding } => {
                     let mut s = datum.as_string().unwrap().as_bytes();
                     for segment in encoding {
                         let data;
@@ -939,23 +884,23 @@ impl Writer {
                         let (chunks, remainder) = data.as_chunks::<8>();
                         for chunk in chunks {
                             if chunk == b"        " {
-                                self.inner.put_opcode(254)?;
+                                self.put_opcode(254)?;
                             } else {
-                                self.inner.put_opcode(253)?;
-                                self.inner.data.extend_from_slice(chunk);
+                                self.put_opcode(253)?;
+                                self.data.extend_from_slice(chunk);
                             }
                         }
                         if !remainder.is_empty() {
                             if remainder.iter().all(|c| *c == b' ') {
-                                self.inner.put_opcode(254)?;
+                                self.put_opcode(254)?;
                             } else {
-                                self.inner.put_opcode(253)?;
-                                self.inner.data.extend_from_slice(remainder);
-                                self.inner.data.extend(repeat_n(0, 8 - remainder.len()));
+                                self.put_opcode(253)?;
+                                self.data.extend_from_slice(remainder);
+                                self.data.extend(repeat_n(0, 8 - remainder.len()));
                             }
                         }
                         for _ in 0..segment.padding_bytes / 8 {
-                            self.inner.put_opcode(254)?;
+                            self.put_opcode(254)?;
                         }
                     }
                 }
@@ -965,16 +910,108 @@ impl Writer {
     }
 }
 
-impl Drop for Writer {
+impl<W> Writer<W>
+where
+    W: Write + Seek,
+{
+    fn new(options: WriteOptions, case_vars: Vec<CaseVar>, inner: W) -> Result<Self, BinError> {
+        Ok(Self {
+            compression: options.compression,
+            case_vars,
+            opcodes: Vec::with_capacity(8),
+            data: Vec::with_capacity(64),
+            n_cases: 0,
+            inner: match options.compression {
+                Some(Compression::ZLib) => Some(Either::Right(ZlibWriter::new(inner)?)),
+                _ => Some(Either::Left(inner)),
+            },
+        })
+    }
+
+    /// Finishes writing the file, flushing buffers and updating headers to
+    /// match the final case counts.
+    pub fn finish(mut self) -> Result<(), BinError> {
+        self.try_finish()
+    }
+
+    /// Tries to finish writing the file, flushing buffers and updating headers
+    /// to match the final case counts.
+    ///
+    /// # Panic
+    ///
+    /// Attempts to write more cases after calling this function may will panic.
+    pub fn try_finish(&mut self) -> Result<(), BinError> {
+        let Some(inner) = self.inner.take() else {
+            return Ok(());
+        };
+
+        let mut inner = match inner {
+            Either::Left(mut inner) => {
+                WriterInner::new(
+                    &self.case_vars,
+                    &mut self.opcodes,
+                    &mut self.data,
+                    &mut inner,
+                )
+                .flush_compressed()?;
+                inner
+            }
+            Either::Right(mut zlib_writer) => {
+                WriterInner::new(
+                    &self.case_vars,
+                    &mut self.opcodes,
+                    &mut self.data,
+                    &mut zlib_writer,
+                )
+                .flush_compressed()?;
+                zlib_writer.finish()?
+            }
+        };
+        if let Ok(n_cases) = u32::try_from(self.n_cases) {
+            if inner.seek(SeekFrom::Start(80)).is_ok() {
+                let _ = inner.write_all(&n_cases.to_le_bytes());
+            }
+        }
+        Ok(())
+    }
+
+    /// Writes `case` to the system file.
+    ///
+    /// # Panic
+    ///
+    /// Panics if [try_finish](Self::try_finish) has been called.
+    pub fn write_case<'a>(
+        &mut self,
+        case: impl IntoIterator<Item = &'a Datum>,
+    ) -> Result<(), BinError> {
+        match self.inner.as_mut().unwrap() {
+            Either::Left(inner) => {
+                let mut inner =
+                    WriterInner::new(&self.case_vars, &mut self.opcodes, &mut self.data, inner);
+                match self.compression {
+                    Some(_) => inner.write_case_compressed(case.into_iter())?,
+                    None => inner.write_case_uncompressed(case.into_iter())?,
+                }
+            }
+            Either::Right(inner) => {
+                WriterInner::new(&self.case_vars, &mut self.opcodes, &mut self.data, inner)
+                    .write_case_compressed(case.into_iter())?
+            }
+        }
+        self.n_cases += 1;
+        Ok(())
+    }
+}
+
+impl<W> Drop for Writer<W>
+where
+    W: Write + Seek,
+{
     fn drop(&mut self) {
         let _ = self.try_finish();
     }
 }
 
-struct Block {
-    uncompressed_size: u64,
-    compressed_size: u64,
-}
 struct ZlibWriter<W>
 where
     W: Write + Seek,
@@ -1033,7 +1070,7 @@ where
         Ok(())
     }
 
-    fn try_finish(&mut self) -> Result<(), BinError> {
+    fn finish(mut self) -> Result<W, BinError> {
         self.flush_block()?;
         let ztrailer_offset = self.inner.stream_position()?;
         self.trailer.write_le(&mut self.inner)?;
@@ -1043,20 +1080,8 @@ where
             ztrailer_len: self.trailer.len() as u64,
         };
         self.inner.seek(SeekFrom::Start(header.zheader_offset))?;
-        header.write_le(&mut self.inner)
-    }
-
-    fn finish(mut self) -> Result<(), BinError> {
-        self.try_finish()
-    }
-}
-
-impl<W> Drop for ZlibWriter<W>
-where
-    W: Write + Seek,
-{
-    fn drop(&mut self) {
-        let _ = self.try_finish();
+        header.write_le(&mut self.inner)?;
+        Ok(self.inner)
     }
 }