From: Ben Pfaff Date: Mon, 21 Jul 2025 23:17:43 +0000 (-0700) Subject: work X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8f28690dd78db853d20152519073c5743f470aea;p=pspp work --- diff --git a/rust/pspp/src/sys/cooked.rs b/rust/pspp/src/sys/cooked.rs index 2651f5fe5e..c62f49964d 100644 --- a/rust/pspp/src/sys/cooked.rs +++ b/rust/pspp/src/sys/cooked.rs @@ -1495,7 +1495,7 @@ impl Metadata { n_cases: headers .number_of_cases .first() - .map(|record| record.n_cases) + .and_then(|record| record.n_cases) .or_else(|| header.n_cases.map(|n| n as u64)), product, product_ext: headers.product_info.first().map(|pe| fix_line_ends(&pe.0)), diff --git a/rust/pspp/src/sys/raw/records.rs b/rust/pspp/src/sys/raw/records.rs index 714f7b4af4..e85a473800 100644 --- a/rust/pspp/src/sys/raw/records.rs +++ b/rust/pspp/src/sys/raw/records.rs @@ -174,7 +174,7 @@ impl FileHeader { let weight_index = (header.weight_index > 0).then_some(header.weight_index); - let n_cases = (header.n_cases < i32::MAX as u32 / 2).then_some(header.n_cases); + let n_cases = (header.n_cases <= u32::MAX / 2).then_some(header.n_cases); if header.bias != 100.0 && header.bias != 0.0 { warn(Warning::new( @@ -1618,7 +1618,7 @@ pub struct NumberOfCasesRecord { pub one: u64, /// Number of cases. - pub n_cases: u64, + pub n_cases: Option, } impl NumberOfCasesRecord { @@ -1629,6 +1629,7 @@ impl NumberOfCasesRecord { let mut input = &ext.data[..]; let one = endian.parse(read_bytes(&mut input)?); let n_cases = endian.parse(read_bytes(&mut input)?); + let n_cases = (n_cases < u64::MAX).then_some(n_cases); Ok(Record::NumberOfCases(NumberOfCasesRecord { one, n_cases })) } @@ -2677,7 +2678,6 @@ impl ZTrailer { where R: Read + Seek, { - dbg!(); let start_offset = reader.stream_position()?; if reader .seek(SeekFrom::Start(zheader.ztrailer_offset)) diff --git a/rust/pspp/src/sys/write.rs b/rust/pspp/src/sys/write.rs index 0f3ba55896..369bea3f25 100644 --- a/rust/pspp/src/sys/write.rs +++ b/rust/pspp/src/sys/write.rs @@ -13,12 +13,12 @@ use std::{ use binrw::{BinWrite, Endian, Error as BinError}; use chrono::Local; use encoding_rs::Encoding; -use flate2::{Compress, FlushCompress, Status}; +use flate2::write::ZlibEncoder; use itertools::zip_eq; use smallvec::SmallVec; use crate::{ - data::{Case, Datum}, + data::Datum, dictionary::{ Alignment, Attributes, CategoryLabels, Dictionary, Measure, MultipleResponseType, ValueLabels, VarWidth, @@ -813,6 +813,10 @@ trait WriteSeek: Write + Seek {} impl WriteSeek for T where T: Write + Seek {} impl WriterInner { + fn finish(mut self) -> Result<(), BinError> { + self.flush_compressed() + } + fn flush_compressed(&mut self) -> Result<(), BinError> { if !self.opcodes.is_empty() { self.opcodes.resize(8, 0); @@ -850,18 +854,39 @@ impl Writer { }, }) } -} -impl Writer { + /// Finishes writing the file, flushing buffers and updating headers to + /// match the final case counts. + pub fn finish(mut self) -> Result<(), BinError> { + self.try_finish() + } + + /// Tries to finish writing the file, flushing buffers and updating headers + /// to match the final case counts. + /// + /// # Panic + /// + /// Attempts to write more cases after calling this function may result in a + /// panic. + pub fn try_finish(&mut self) -> Result<(), BinError> { + self.inner.flush_compressed() + } + /// Writes `case` to the system file. - pub fn write_case(&mut self, case: &Case) -> Result<(), BinError> { + pub fn write_case<'a>( + &mut self, + case: impl IntoIterator, + ) -> Result<(), BinError> { match self.compression { - Some(_) => self.write_case_compressed(case), - None => self.write_case_uncompressed(case), + Some(_) => self.write_case_compressed(case.into_iter()), + None => self.write_case_uncompressed(case.into_iter()), } } - fn write_case_uncompressed(&mut self, case: &Case) -> Result<(), BinError> { - for (var, datum) in zip_eq(&self.case_vars, &case.0) { + fn write_case_uncompressed<'a>( + &mut self, + case: impl Iterator, + ) -> Result<(), BinError> { + for (var, datum) in zip_eq(&self.case_vars, case) { match var { CaseVar::Numeric => datum .as_number() @@ -881,8 +906,11 @@ impl Writer { } Ok(()) } - fn write_case_compressed(&mut self, case: &Case) -> Result<(), BinError> { - for (var, datum) in zip_eq(&self.case_vars, &case.0) { + fn write_case_compressed<'a>( + &mut self, + case: impl Iterator, + ) -> Result<(), BinError> { + for (var, datum) in zip_eq(&self.case_vars, case) { match var { CaseVar::Numeric => match datum.as_number().unwrap() { None => self.inner.put_opcode(255)?, @@ -937,6 +965,12 @@ impl Writer { } } +impl Drop for Writer { + fn drop(&mut self) { + let _ = self.try_finish(); + } +} + struct Block { uncompressed_size: u64, compressed_size: u64, @@ -947,8 +981,7 @@ where { header: RawZHeader, trailer: RawZTrailer, - compress: Compress, - buf: Vec, + encoder: ZlibEncoder>, inner: W, } @@ -971,13 +1004,37 @@ where block_size: ZBLOCK_SIZE as u32, blocks: Vec::new(), }, - compress: Compress::new(flate2::Compression::new(5), false), - buf: Vec::with_capacity(4096), + encoder: ZlibEncoder::new(Vec::new(), flate2::Compression::new(1)), inner, }) } + + fn flush_block(&mut self) -> std::io::Result<()> { + let total_in = self.encoder.total_in(); + if total_in > 0 { + let buf = self.encoder.reset(Vec::new())?; + let total_out = buf.len(); + self.inner.write_all(&buf)?; + self.encoder.reset(buf).unwrap(); + + self.trailer.blocks.push(ZBlock { + uncompressed_size: total_in as u32, + compressed_size: total_out as u32, + uncompressed_ofs: match self.trailer.blocks.last() { + Some(prev) => prev.uncompressed_ofs + prev.uncompressed_size as u64, + None => self.header.zheader_offset, + }, + compressed_ofs: match self.trailer.blocks.last() { + Some(prev) => prev.compressed_ofs + prev.compressed_size as u64, + None => self.header.zheader_offset + 24, + }, + }); + } + Ok(()) + } + fn try_finish(&mut self) -> Result<(), BinError> { - self.flush()?; + self.flush_block()?; let ztrailer_offset = self.inner.stream_position()?; self.trailer.write_le(&mut self.inner)?; let header = RawZHeader { @@ -985,10 +1042,10 @@ where ztrailer_offset, ztrailer_len: self.trailer.len() as u64, }; - dbg!(&header); self.inner.seek(SeekFrom::Start(header.zheader_offset))?; header.write_le(&mut self.inner) } + fn finish(mut self) -> Result<(), BinError> { self.try_finish() } @@ -999,7 +1056,6 @@ where W: Write + Seek, { fn drop(&mut self) { - dbg!(); let _ = self.try_finish(); } } @@ -1013,52 +1069,20 @@ where fn write(&mut self, mut buf: &[u8]) -> Result { let n = buf.len(); while buf.len() > 0 { - if self.compress.total_in() >= ZBLOCK_SIZE { - self.flush()?; + if self.encoder.total_in() >= ZBLOCK_SIZE { + self.flush_block()?; } let chunk = buf .len() - .min((ZBLOCK_SIZE - self.compress.total_in()) as usize); - let in_before = self.compress.total_in(); - self.buf.clear(); - self.compress - .compress_vec(&buf[..chunk], &mut self.buf, FlushCompress::None) - .unwrap(); - let consumed = self.compress.total_in() - in_before; - self.inner.write_all(&self.buf)?; - buf = &buf[consumed as usize..]; + .min((ZBLOCK_SIZE - self.encoder.total_in()) as usize); + self.encoder.write_all(&buf[..chunk])?; + buf = &buf[chunk..]; } Ok(n) } fn flush(&mut self) -> std::io::Result<()> { - if self.compress.total_in() > 0 { - let mut status = Status::Ok; - while status == Status::Ok { - self.buf.clear(); - status = self - .compress - .compress_vec(&[], &mut self.buf, FlushCompress::Finish) - .unwrap(); - self.inner.write_all(&self.buf)?; - } - assert_eq!(status, Status::StreamEnd); - - self.trailer.blocks.push(ZBlock { - uncompressed_size: self.compress.total_in() as u32, - compressed_size: self.compress.total_out() as u32, - uncompressed_ofs: match self.trailer.blocks.last() { - Some(prev) => prev.uncompressed_ofs + prev.uncompressed_size as u64, - None => self.header.zheader_offset, - }, - compressed_ofs: match self.trailer.blocks.last() { - Some(prev) => prev.compressed_ofs + prev.compressed_size as u64, - None => self.header.zheader_offset + 24, - }, - }); - self.compress.reset(); - } Ok(()) } } @@ -1068,7 +1092,6 @@ where W: Write + Seek, { fn seek(&mut self, _pos: std::io::SeekFrom) -> Result { - panic!(); Err(IoError::from(ErrorKind::NotSeekable)) } }