get rid of Rc<RefCell<Cases>>
author Ben Pfaff <blp@cs.stanford.edu>
Sun, 6 Jul 2025 15:37:27 +0000 (08:37 -0700)
committer Ben Pfaff <blp@cs.stanford.edu>
Sun, 6 Jul 2025 15:37:27 +0000 (08:37 -0700)
rust/pspp/src/main.rs
rust/pspp/src/sys/cooked.rs
rust/pspp/src/sys/raw.rs
rust/pspp/src/sys/test.rs

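diff --git a/rust/pspp/src/main.rs b/rust/pspp/src/main.rs
index 16138eef89cb555c63124484742265b8c33ea837..8677e3e01b9de78e19d4241e21d2316b9b8bdf11 100644
--- a/rust/pspp/src/main.rs
+++ b/rust/pspp/src/main.rs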
@@ -92,7 +92,7 @@ fn dissect(
 
     match mode {
         Mode::Identify => {
-            let Record::Header(header) = reader.next().unwrap()? else {
+            let Record::Header(header) = reader.headers().next().unwrap()? else {
                 unreachable!()
             };
             match header.magic {
@@ -103,22 +103,18 @@ fn dissect(
             return Ok(());
         }
         Mode::Raw => {
-            for header in reader {
+            for header in reader.headers() {
                 let header = header?;
                 println!("{:?}", header);
-                if let Record::Cases(cases) = header {
-                    let mut cases = cases.borrow_mut();
-                    for _ in 0..max_cases {
-                        let Some(Ok(record)) = cases.next() else {
-                            break;
-                        };
-                        println!("{:?}", record);
-                    }
+            }
+            if let Some(cases) = reader.cases() {
+                for (_index, case) in (0..max_cases).zip(cases) {
+                    println!("{:?}", case?);
                 }
             }
         }
         Mode::Decoded => {
-            let headers: Vec<Record> = reader.collect::<Result<Vec<_>, _>>()?;
+            let headers: Vec<Record> = reader.headers().collect::<Result<Vec<_>, _>>()?;
             let encoding = match encoding {
                 Some(encoding) => encoding,
                 None => encoding_from_headers(&headers, &mut |e| eprintln!("{e}"))?,
@@ -141,7 +137,7 @@ fn dissect(
             }
         }
         Mode::Cooked => {
-            let headers: Vec<Record> = reader.collect::<Result<Vec<_>, _>>()?;
+            let headers: Vec<Record> = reader.headers().collect::<Result<Vec<_>, _>>()?;
             let encoding = match encoding {
                 Some(encoding) => encoding,
                 None => encoding_from_headers(&headers, &mut |e| eprintln!("{e}"))?,
@@ -152,7 +148,8 @@ fn dissect(
                 decoded_records.push(header.decode(&mut decoder)?);
             }
             let headers = Headers::new(decoded_records, &mut |e| eprintln!("{e}"))?;
-            let (dictionary, metadata, _cases) = decode(headers, encoding, |e| eprintln!("{e}"))?;
+            let (dictionary, metadata, _cases) =
+                decode(headers, None, encoding, |e| eprintln!("{e}"))?;
             println!("{dictionary:#?}");
             println!("{metadata:#?}");
         }
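diff --git a/rust/pspp/src/sys/cooked.rs b/rust/pspp/src/sys/cooked.rs
index 86ad39edde342f6dd97a47ff89c9793cc32d66b7..a956dd8605e3e55b86c147c7fece64e5e84328ab 100644
--- a/rust/pspp/src/sys/cooked.rs
+++ b/rust/pspp/src/sys/cooked.rs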
@@ -1,4 +1,4 @@
-use std::{cell::RefCell, collections::BTreeMap, ops::Range, rc::Rc};
+use std::{collections::BTreeMap, ops::Range};
 
 use crate::{
     calendar::date_time_to_pspp,
@@ -313,7 +313,6 @@ pub struct Headers {
     pub end_of_headers: Option<u32>,
     pub z_header: Option<ZHeader>,
     pub z_trailer: Option<ZTrailer>,
-    pub cases: Option<Rc<RefCell<Cases>>>,
 }
 
 fn take_first<T>(
@@ -354,7 +353,6 @@ impl Headers {
         let mut end_of_headers = Vec::new();
         let mut z_header = Vec::new();
         let mut z_trailer = Vec::new();
-        let mut cases = Vec::new();
 
         for header in headers {
             match header {
@@ -424,9 +422,6 @@ impl Headers {
                 DecodedRecord::ZTrailer(record) => {
                     z_trailer.push(record);
                 }
-                DecodedRecord::Cases(record) => {
-                    cases.push(record);
-                }
             }
         }
 
@@ -457,7 +452,6 @@ impl Headers {
             end_of_headers: take_first(end_of_headers, "end of headers", warn),
             z_header: take_first(z_header, "z_header", warn),
             z_trailer: take_first(z_trailer, "z_trailer", warn),
-            cases: take_first(cases, "cases", warn),
         })
     }
 }
@@ -584,9 +578,10 @@ impl Decoder {
 
 pub fn decode(
     mut headers: Headers,
+    cases: Option<Cases>,
     encoding: &'static Encoding,
     mut warn: impl FnMut(Error),
-) -> Result<(Dictionary, Metadata, Rc<RefCell<Cases>>), Error> {
+) -> Result<(Dictionary, Metadata, Option<Cases>), Error> {
     let mut dictionary = Dictionary::new(encoding);
 
     let file_label = fix_line_ends(headers.header.file_label.trim_end_matches(' '));
@@ -1093,7 +1088,7 @@ pub fn decode(
     }
 
     let metadata = Metadata::decode(&headers, warn);
-    Ok((dictionary, metadata, headers.cases.take().unwrap()))
+    Ok((dictionary, metadata, cases))
 }
 
 impl MultipleResponseSet {
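diff --git a/rust/pspp/src/sys/raw.rs b/rust/pspp/src/sys/raw.rs
index 86f2971af11f897bd6e52907468fa06604c0a0ef..46c201a86520592176b399a633016b9bac8d2d20 100644
--- a/rust/pspp/src/sys/raw.rs
+++ b/rust/pspp/src/sys/raw.rs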
@@ -16,12 +16,11 @@ use std::{
     cell::RefCell,
     collections::{BTreeMap, VecDeque},
     fmt::{Debug, Display, Formatter, Result as FmtResult},
-    io::{Error as IoError, Read, Seek, SeekFrom},
+    io::{empty, Error as IoError, Read, Seek, SeekFrom},
     iter::repeat_n,
     mem::take,
     num::NonZeroU8,
     ops::{Deref, Not, Range},
-    rc::Rc,
     str::from_utf8,
 };
 use thiserror::Error as ThisError;
@@ -312,7 +311,6 @@ pub enum Record {
     EndOfHeaders(u32),
     ZHeader(ZHeader),
     ZTrailer(ZTrailer),
-    Cases(Rc<RefCell<Cases>>),
 }
 
 #[derive(Clone, Debug)]
@@ -339,7 +337,6 @@ pub enum DecodedRecord {
     EndOfHeaders(u32),
     ZHeader(ZHeader),
     ZTrailer(ZTrailer),
-    Cases(Rc<RefCell<Cases>>),
 }
 
 impl Record {
@@ -400,7 +397,6 @@ impl Record {
             Record::EndOfHeaders(record) => DecodedRecord::EndOfHeaders(record),
             Record::ZHeader(record) => DecodedRecord::ZHeader(record.clone()),
             Record::ZTrailer(record) => DecodedRecord::ZTrailer(record.clone()),
-            Record::Cases(record) => DecodedRecord::Cases(record.clone()),
         })
     }
 }
@@ -968,7 +964,6 @@ enum ReaderState {
         ztrailer_offset: u64,
         ztrailer_len: u64,
     },
-    Cases,
     End,
 }
 
@@ -983,6 +978,7 @@ where
     var_types: VarTypes,
 
     state: ReaderState,
+    cases: Option<Cases>,
 }
 
 impl<'a, R> Reader<'a, R>
@@ -997,29 +993,47 @@ where
             header,
             var_types: VarTypes::new(),
             state: ReaderState::Start,
+            cases: None,
         })
     }
-    fn cases(&mut self) -> Cases {
-        self.state = ReaderState::End;
-        Cases::new(
-            self.reader.take().unwrap(),
-            take(&mut self.var_types),
-            &self.header,
-        )
+    pub fn headers<'b>(&'b mut self) -> ReadHeaders<'a, 'b, R> {
+        ReadHeaders(self)
     }
+    pub fn cases(self) -> Option<Cases> {
+        self.cases
+    }
+}
+
+pub struct ReadHeaders<'a, 'b, R>(&'b mut Reader<'a, R>)
+where
+    R: Read + Seek + 'static;
+
+impl<'a, 'b, R> ReadHeaders<'a, 'b, R>
+where
+    R: Read + Seek + 'static,
+{
+    fn cases(&mut self) {
+        self.0.state = ReaderState::End;
+        self.0.cases = Some(Cases::new(
+            self.0.reader.take().unwrap(),
+            take(&mut self.0.var_types),
+            &self.0.header,
+        ));
+    }
+
     fn _next(&mut self) -> Option<<Self as Iterator>::Item> {
-        match self.state {
+        match self.0.state {
             ReaderState::Start => {
-                self.state = ReaderState::Headers;
-                Some(Ok(Record::Header(self.header.clone())))
+                self.0.state = ReaderState::Headers;
+                Some(Ok(Record::Header(self.0.header.clone())))
             }
             ReaderState::Headers => {
                 let record = loop {
                     match Record::read(
-                        self.reader.as_mut().unwrap(),
-                        self.header.endian,
-                        &self.var_types,
-                        &mut self.warn,
+                        self.0.reader.as_mut().unwrap(),
+                        self.0.header.endian,
+                        &self.0.var_types,
+                        &mut self.0.warn,
                     ) {
                         Ok(Some(record)) => break record,
                         Ok(None) => (),
@@ -1027,12 +1041,13 @@ where
                     }
                 };
                 match record {
-                    Record::Variable(VariableRecord { width, .. }) => self.var_types.push(width),
+                    Record::Variable(VariableRecord { width, .. }) => self.0.var_types.push(width),
                     Record::EndOfHeaders(_) => {
-                        self.state = if let Some(Compression::ZLib) = self.header.compression {
+                        self.0.state = if let Some(Compression::ZLib) = self.0.header.compression {
                             ReaderState::ZlibHeader
                         } else {
-                            ReaderState::Cases
+                            self.cases();
+                            ReaderState::End
                         };
                     }
                     _ => (),
@@ -1040,12 +1055,12 @@ where
                 Some(Ok(record))
             }
             ReaderState::ZlibHeader => {
-                let zheader = match ZHeader::read(self.reader.as_mut().unwrap(), self.header.endian)
-                {
-                    Ok(zheader) => zheader,
-                    Err(error) => return Some(Err(error)),
-                };
-                self.state = ReaderState::ZlibTrailer {
+                let zheader =
+                    match ZHeader::read(self.0.reader.as_mut().unwrap(), self.0.header.endian) {
+                        Ok(zheader) => zheader,
+                        Err(error) => return Some(Err(error)),
+                    };
+                self.0.state = ReaderState::ZlibTrailer {
                     ztrailer_offset: zheader.ztrailer_offset,
                     ztrailer_len: zheader.ztrailer_len,
                 };
@@ -1056,26 +1071,28 @@ where
                 ztrailer_len,
             } => {
                 match ZTrailer::read(
-                    self.reader.as_mut().unwrap(),
-                    self.header.endian,
+                    self.0.reader.as_mut().unwrap(),
+                    self.0.header.endian,
                     ztrailer_offset,
                     ztrailer_len,
                 ) {
-                    Ok(None) => Some(Ok(Record::Cases(Rc::new(RefCell::new(self.cases()))))),
+                    Ok(None) => {
+                        self.cases();
+                        None
+                    }
                     Ok(Some(ztrailer)) => {
-                        self.state = ReaderState::Cases;
+                        self.cases();
                         Some(Ok(Record::ZTrailer(ztrailer)))
                     }
                     Err(error) => Some(Err(error)),
                 }
             }
-            ReaderState::Cases => Some(Ok(Record::Cases(Rc::new(RefCell::new(self.cases()))))),
             ReaderState::End => None,
         }
     }
 }
 
-impl<'a, R> Iterator for Reader<'a, R>
+impl<'a, 'b, R> Iterator for ReadHeaders<'a, 'b, R>
 where
     R: Read + Seek + 'static,
 {
@@ -1084,7 +1101,7 @@ where
     fn next(&mut self) -> Option<Self::Item> {
         let retval = self._next();
         if matches!(retval, Some(Err(_))) {
-            self.state = ReaderState::End;
+            self.0.state = ReaderState::End;
         }
         retval
     }
@@ -1172,6 +1189,20 @@ impl Debug for Cases {
     }
 }
 
+impl Default for Cases {
+    fn default() -> Self {
+        Self {
+            reader: Box::new(empty()),
+            case_vars: Vec::new(),
+            compression: None,
+            bias: 100.0,
+            endian: Endian::Little,
+            codes: VecDeque::new(),
+            eof: true,
+        }
+    }
+}
+
 impl Cases {
     fn new<R>(reader: R, var_types: VarTypes, header: &HeaderRecord<RawString>) -> Self
     where
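diff --git a/rust/pspp/src/sys/test.rs b/rust/pspp/src/sys/test.rs
index 8004947c68e5a5d3bad715803e54f058bd499111..f3a5ade4c9520cbeb1e253f0be1e4d5b891ff4c5 100644
--- a/rust/pspp/src/sys/test.rs
+++ b/rust/pspp/src/sys/test.rs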
@@ -530,9 +530,10 @@ fn test_sysfile(name: &str) {
         let sysfile = sack(&input, Some(&input_filename), endian).unwrap();
         let cursor = Cursor::new(sysfile);
         let mut warnings = Vec::new();
-        let reader = Reader::new(cursor, |warning| warnings.push(warning)).unwrap();
-        let output = match reader.collect() {
+        let mut reader = Reader::new(cursor, |warning| warnings.push(warning)).unwrap();
+        let output = match reader.headers().collect() {
             Ok(headers) => {
+                drop(reader);
                 let encoding =
                     encoding_from_headers(&headers, &mut |warning| warnings.push(warning)).unwrap();
                 let mut decoder = Decoder::new(encoding, |warning| warnings.push(warning));
@@ -544,8 +545,8 @@ fn test_sysfile(name: &str) {
 
                 let mut errors = Vec::new();
                 let headers = Headers::new(decoded_records, &mut |e| errors.push(e)).unwrap();
-                let (dictionary, metadata, cases) =
-                    decode(headers, encoding, |e| errors.push(e)).unwrap();
+                let (dictionary, metadata, _cases) =
+                    decode(headers, None, encoding, |e| errors.push(e)).unwrap();
                 let (group, data) = metadata.to_pivot_rows();
                 let metadata_table = PivotTable::new([(Axis3::Y, Dimension::new(group))])
                     .with_data(