work
authorBen Pfaff <blp@cs.stanford.edu>
Mon, 21 Jul 2025 23:17:43 +0000 (16:17 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Mon, 21 Jul 2025 23:17:43 +0000 (16:17 -0700)
rust/pspp/src/sys/cooked.rs
rust/pspp/src/sys/raw/records.rs
rust/pspp/src/sys/write.rs

index 2651f5fe5ee4ba100a66a9fbd9f69ddb8e2793a4..c62f49964d873dbcaa40e583ccf353db0ff76ca4 100644 (file)
@@ -1495,7 +1495,7 @@ impl Metadata {
             n_cases: headers
                 .number_of_cases
                 .first()
-                .map(|record| record.n_cases)
+                .and_then(|record| record.n_cases)
                 .or_else(|| header.n_cases.map(|n| n as u64)),
             product,
             product_ext: headers.product_info.first().map(|pe| fix_line_ends(&pe.0)),
index 714f7b4af4c850dfab0865d4eaa7b02769162f0e..e85a473800cb7f0e50596e456e28572e824daf8f 100644 (file)
@@ -174,7 +174,7 @@ impl FileHeader<RawString> {
 
         let weight_index = (header.weight_index > 0).then_some(header.weight_index);
 
-        let n_cases = (header.n_cases < i32::MAX as u32 / 2).then_some(header.n_cases);
+        let n_cases = (header.n_cases <= u32::MAX / 2).then_some(header.n_cases);
 
         if header.bias != 100.0 && header.bias != 0.0 {
             warn(Warning::new(
@@ -1618,7 +1618,7 @@ pub struct NumberOfCasesRecord {
     pub one: u64,
 
     /// Number of cases.
-    pub n_cases: u64,
+    pub n_cases: Option<u64>,
 }
 
 impl NumberOfCasesRecord {
@@ -1629,6 +1629,7 @@ impl NumberOfCasesRecord {
         let mut input = &ext.data[..];
         let one = endian.parse(read_bytes(&mut input)?);
         let n_cases = endian.parse(read_bytes(&mut input)?);
+        let n_cases = (n_cases < u64::MAX).then_some(n_cases);
 
         Ok(Record::NumberOfCases(NumberOfCasesRecord { one, n_cases }))
     }
@@ -2677,7 +2678,6 @@ impl ZTrailer {
     where
         R: Read + Seek,
     {
-        dbg!();
         let start_offset = reader.stream_position()?;
         if reader
             .seek(SeekFrom::Start(zheader.ztrailer_offset))
index 0f3ba5589695aeb21cac1089ca5d95f3e5a38e4f..369bea3f25fc8a8fc362992e1e1233d96d239e02 100644 (file)
@@ -13,12 +13,12 @@ use std::{
 use binrw::{BinWrite, Endian, Error as BinError};
 use chrono::Local;
 use encoding_rs::Encoding;
-use flate2::{Compress, FlushCompress, Status};
+use flate2::write::ZlibEncoder;
 use itertools::zip_eq;
 use smallvec::SmallVec;
 
 use crate::{
-    data::{Case, Datum},
+    data::Datum,
     dictionary::{
         Alignment, Attributes, CategoryLabels, Dictionary, Measure, MultipleResponseType,
         ValueLabels, VarWidth,
@@ -813,6 +813,10 @@ trait WriteSeek: Write + Seek {}
 impl<T> WriteSeek for T where T: Write + Seek {}
 
 impl WriterInner {
+    fn finish(mut self) -> Result<(), BinError> {
+        self.flush_compressed()
+    }
+
     fn flush_compressed(&mut self) -> Result<(), BinError> {
         if !self.opcodes.is_empty() {
             self.opcodes.resize(8, 0);
@@ -850,18 +854,39 @@ impl Writer {
             },
         })
     }
-}
 
-impl Writer {
+    /// Finishes writing the file, flushing buffers and updating headers to
+    /// match the final case counts.
+    pub fn finish(mut self) -> Result<(), BinError> {
+        self.try_finish()
+    }
+
+    /// Tries to finish writing the file, flushing buffers and updating headers
+    /// to match the final case counts.
+    ///
+    /// # Panics
+    ///
+    /// Attempts to write more cases after calling this function may result in a
+    /// panic.
+    pub fn try_finish(&mut self) -> Result<(), BinError> {
+        self.inner.flush_compressed()
+    }
+
     /// Writes `case` to the system file.
-    pub fn write_case(&mut self, case: &Case) -> Result<(), BinError> {
+    pub fn write_case<'a>(
+        &mut self,
+        case: impl IntoIterator<Item = &'a Datum>,
+    ) -> Result<(), BinError> {
         match self.compression {
-            Some(_) => self.write_case_compressed(case),
-            None => self.write_case_uncompressed(case),
+            Some(_) => self.write_case_compressed(case.into_iter()),
+            None => self.write_case_uncompressed(case.into_iter()),
         }
     }
-    fn write_case_uncompressed(&mut self, case: &Case) -> Result<(), BinError> {
-        for (var, datum) in zip_eq(&self.case_vars, &case.0) {
+    fn write_case_uncompressed<'a>(
+        &mut self,
+        case: impl Iterator<Item = &'a Datum>,
+    ) -> Result<(), BinError> {
+        for (var, datum) in zip_eq(&self.case_vars, case) {
             match var {
                 CaseVar::Numeric => datum
                     .as_number()
@@ -881,8 +906,11 @@ impl Writer {
         }
         Ok(())
     }
-    fn write_case_compressed(&mut self, case: &Case) -> Result<(), BinError> {
-        for (var, datum) in zip_eq(&self.case_vars, &case.0) {
+    fn write_case_compressed<'a>(
+        &mut self,
+        case: impl Iterator<Item = &'a Datum>,
+    ) -> Result<(), BinError> {
+        for (var, datum) in zip_eq(&self.case_vars, case) {
             match var {
                 CaseVar::Numeric => match datum.as_number().unwrap() {
                     None => self.inner.put_opcode(255)?,
@@ -937,6 +965,12 @@ impl Writer {
     }
 }
 
+impl Drop for Writer {
+    fn drop(&mut self) {
+        let _ = self.try_finish();
+    }
+}
+
 struct Block {
     uncompressed_size: u64,
     compressed_size: u64,
@@ -947,8 +981,7 @@ where
 {
     header: RawZHeader,
     trailer: RawZTrailer,
-    compress: Compress,
-    buf: Vec<u8>,
+    encoder: ZlibEncoder<Vec<u8>>,
     inner: W,
 }
 
@@ -971,13 +1004,37 @@ where
                 block_size: ZBLOCK_SIZE as u32,
                 blocks: Vec::new(),
             },
-            compress: Compress::new(flate2::Compression::new(5), false),
-            buf: Vec::with_capacity(4096),
+            encoder: ZlibEncoder::new(Vec::new(), flate2::Compression::new(1)),
             inner,
         })
     }
+
+    fn flush_block(&mut self) -> std::io::Result<()> {
+        let total_in = self.encoder.total_in();
+        if total_in > 0 {
+            let buf = self.encoder.reset(Vec::new())?;
+            let total_out = buf.len();
+            self.inner.write_all(&buf)?;
+            self.encoder.reset(buf).unwrap();
+
+            self.trailer.blocks.push(ZBlock {
+                uncompressed_size: total_in as u32,
+                compressed_size: total_out as u32,
+                uncompressed_ofs: match self.trailer.blocks.last() {
+                    Some(prev) => prev.uncompressed_ofs + prev.uncompressed_size as u64,
+                    None => self.header.zheader_offset,
+                },
+                compressed_ofs: match self.trailer.blocks.last() {
+                    Some(prev) => prev.compressed_ofs + prev.compressed_size as u64,
+                    None => self.header.zheader_offset + 24,
+                },
+            });
+        }
+        Ok(())
+    }
+
     fn try_finish(&mut self) -> Result<(), BinError> {
-        self.flush()?;
+        self.flush_block()?;
         let ztrailer_offset = self.inner.stream_position()?;
         self.trailer.write_le(&mut self.inner)?;
         let header = RawZHeader {
@@ -985,10 +1042,10 @@ where
             ztrailer_offset,
             ztrailer_len: self.trailer.len() as u64,
         };
-        dbg!(&header);
         self.inner.seek(SeekFrom::Start(header.zheader_offset))?;
         header.write_le(&mut self.inner)
     }
+
     fn finish(mut self) -> Result<(), BinError> {
         self.try_finish()
     }
@@ -999,7 +1056,6 @@ where
     W: Write + Seek,
 {
     fn drop(&mut self) {
-        dbg!();
         let _ = self.try_finish();
     }
 }
@@ -1013,52 +1069,20 @@ where
     fn write(&mut self, mut buf: &[u8]) -> Result<usize, IoError> {
         let n = buf.len();
         while buf.len() > 0 {
-            if self.compress.total_in() >= ZBLOCK_SIZE {
-                self.flush()?;
+            if self.encoder.total_in() >= ZBLOCK_SIZE {
+                self.flush_block()?;
             }
 
             let chunk = buf
                 .len()
-                .min((ZBLOCK_SIZE - self.compress.total_in()) as usize);
-            let in_before = self.compress.total_in();
-            self.buf.clear();
-            self.compress
-                .compress_vec(&buf[..chunk], &mut self.buf, FlushCompress::None)
-                .unwrap();
-            let consumed = self.compress.total_in() - in_before;
-            self.inner.write_all(&self.buf)?;
-            buf = &buf[consumed as usize..];
+                .min((ZBLOCK_SIZE - self.encoder.total_in()) as usize);
+            self.encoder.write_all(&buf[..chunk])?;
+            buf = &buf[chunk..];
         }
         Ok(n)
     }
 
     fn flush(&mut self) -> std::io::Result<()> {
-        if self.compress.total_in() > 0 {
-            let mut status = Status::Ok;
-            while status == Status::Ok {
-                self.buf.clear();
-                status = self
-                    .compress
-                    .compress_vec(&[], &mut self.buf, FlushCompress::Finish)
-                    .unwrap();
-                self.inner.write_all(&self.buf)?;
-            }
-            assert_eq!(status, Status::StreamEnd);
-
-            self.trailer.blocks.push(ZBlock {
-                uncompressed_size: self.compress.total_in() as u32,
-                compressed_size: self.compress.total_out() as u32,
-                uncompressed_ofs: match self.trailer.blocks.last() {
-                    Some(prev) => prev.uncompressed_ofs + prev.uncompressed_size as u64,
-                    None => self.header.zheader_offset,
-                },
-                compressed_ofs: match self.trailer.blocks.last() {
-                    Some(prev) => prev.compressed_ofs + prev.compressed_size as u64,
-                    None => self.header.zheader_offset + 24,
-                },
-            });
-            self.compress.reset();
-        }
         Ok(())
     }
 }
@@ -1068,7 +1092,6 @@ where
     W: Write + Seek,
 {
     fn seek(&mut self, _pos: std::io::SeekFrom) -> Result<u64, IoError> {
-        panic!();
         Err(IoError::from(ErrorKind::NotSeekable))
     }
 }