work
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 9 Dec 2023 21:07:41 +0000 (13:07 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 9 Dec 2023 21:07:41 +0000 (13:07 -0800)
rust/src/cooked.rs
rust/src/main.rs
rust/src/raw.rs

index d00f3f3c34f7a6087570c6315e98b0539f82c2ab..0206b84b75fe5fc2e23209925b1b55fbbd215df2 100644 (file)
@@ -355,7 +355,7 @@ pub fn decode(
             raw::Record::EndOfHeaders(_) => (),
             raw::Record::ZHeader(_) => (),
             raw::Record::ZTrailer(_) => (),
-            raw::Record::Case(_) => (),
+            raw::Record::Cases(_) => (),
         };
     }
     Ok(output)
index 213b381a6e670092cd966c586d2a8006528e504f..45d0622f0d4e039f09c4973596d2d4a5b1b9078c 100644 (file)
@@ -95,25 +95,28 @@ fn dissect(file_name: &Path, max_cases: u64, mode: Mode, encoding: Option<&'stat
             return Ok(())
         }
         Mode::Raw => {
-            let headers: Vec<Record> = reader.collect_headers()?;
-            for header in headers {
-                println!("{header:?}");
+            for header in reader {
+                let header = header?;
+                println!("{:?}", header);
+                if let Record::Cases(cases) = header {
+                    let mut cases = cases.borrow_mut();
+                    for _ in 0..max_cases {
+                        let Some(Ok(record)) = cases.next() else {
+                            break;
+                        };
+                        println!("{:?}", record);
+                    }
+                }
             }
         }
         Mode::Cooked => {
-            let headers: Vec<Record> = reader.collect_headers()?;
-            let headers = decode(headers, encoding, &|e| panic!("{e}"))?;
+            let headers: Vec<Record> = reader.collect::<Result<Vec<_>, _>>()?;
+            let headers = decode(headers, encoding, &|e| eprintln!("{e}"))?;
             for header in headers {
                 println!("{header:?}");
             }
         }
     }
 
-    for _ in 0..max_cases {
-        let Some(Ok(record)) = reader.next() else {
-            break;
-        };
-        println!("{:?}", record);
-    }
     Ok(())
 }
index 222a39b01dd7b9071e93d972b552222c85ae1d18..8b69f760d6364299cc19271f23efb9570bfd8975 100644 (file)
@@ -3,20 +3,19 @@ use crate::endian::{Endian, Parse, ToBytes};
 use encoding_rs::mem::decode_latin1;
 use flate2::read::ZlibDecoder;
 use num::Integer;
-use std::borrow::Cow;
-use std::cmp::Ordering;
-use std::fmt::{Debug, Formatter, Result as FmtResult};
-use std::ops::Range;
-use std::str::from_utf8;
 use std::{
+    borrow::Cow,
+    cmp::Ordering,
     collections::VecDeque,
+    fmt::{Debug, Formatter, Result as FmtResult},
     io::{Error as IoError, Read, Seek, SeekFrom},
-    iter::FusedIterator,
+    mem::take,
+    ops::Range,
+    rc::Rc,
+    str::from_utf8, cell::RefCell,
 };
 use thiserror::Error as ThisError;
 
-use self::state::State;
-
 #[derive(ThisError, Debug)]
 pub enum Error {
     #[error("Not an SPSS system file")]
@@ -158,26 +157,27 @@ pub enum Record {
     EndOfHeaders(u32),
     ZHeader(ZHeader),
     ZTrailer(ZTrailer),
-    Case(Vec<Value>),
+    Cases(Rc<RefCell<Cases>>),
 }
 
 impl Record {
-    fn read<R: Read + Seek>(reader: &mut R, endian: Endian, warn: &Box<dyn Fn(Error)>) -> Result<Record, Error> {
-        loop {
-            if let Some(record) = Self::_read(reader, endian, warn)? {
-                return Ok(record);
-            }
-        }
-    }
-
-    fn _read<R: Read + Seek>(reader: &mut R, endian: Endian, warn: &Box<dyn Fn(Error)>) -> Result<Option<Record>, Error> {
+    fn read<R>(
+        reader: &mut R,
+        endian: Endian,
+        warn: &Box<dyn Fn(Error)>,
+    ) -> Result<Option<Record>, Error>
+    where
+        R: Read + Seek,
+    {
         let rec_type: u32 = endian.parse(read_bytes(reader)?);
         match rec_type {
             2 => Ok(Some(VariableRecord::read(reader, endian)?)),
             3 => Ok(Some(ValueLabelRecord::read(reader, endian)?)),
             6 => Ok(Some(DocumentRecord::read(reader, endian)?)),
             7 => Extension::read(reader, endian, warn),
-            999 => Ok(Some(Record::EndOfHeaders(endian.parse(read_bytes(reader)?)))),
+            999 => Ok(Some(Record::EndOfHeaders(
+                endian.parse(read_bytes(reader)?),
+            ))),
             _ => Err(Error::BadRecordType {
                 offset: reader.stream_position()?,
                 rec_type,
@@ -192,7 +192,7 @@ fn default_decode(s: &[u8]) -> Cow<str> {
     from_utf8(s).map_or_else(|_| decode_latin1(s), Cow::from)
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub enum Compression {
     Simple,
     ZLib,
@@ -398,154 +398,6 @@ impl VarType {
     }
 }
 
-mod state {
-    use super::{
-        Compression, Error, HeaderRecord, Record, Value, VarType, VariableRecord, ZHeader,
-        ZTrailer, ZlibDecodeMultiple,
-    };
-    use crate::endian::Endian;
-    use std::{
-        collections::VecDeque,
-        io::{Read, Seek},
-    };
-
-    pub trait State {
-        #[allow(clippy::type_complexity)]
-        fn read(self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error>;
-    }
-
-    struct Start<R: Read + Seek> {
-        reader: R,
-        warn: Box<dyn Fn(Error)>
-    }
-
-    pub fn new<R: Read + Seek + 'static, F: Fn(Error) + 'static >(reader: R, warn: F) -> Box<dyn State> {
-        Box::new(Start { reader, warn: Box::new(warn) })
-    }
-
-    struct CommonState<R: Read + Seek> {
-        reader: R,
-        warn: Box<dyn Fn(Error)>,
-        endian: Endian,
-        bias: f64,
-        compression: Option<Compression>,
-        var_types: Vec<VarType>,
-    }
-
-    impl<R: Read + Seek + 'static> State for Start<R> {
-        fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-            let header = HeaderRecord::read(&mut self.reader)?;
-            let next_state = Headers(CommonState {
-                reader: self.reader,
-                warn: self.warn,
-                endian: header.endian,
-                bias: header.bias,
-                compression: header.compression,
-                var_types: Vec::new(),
-            });
-            Ok(Some((Record::Header(header), Box::new(next_state))))
-        }
-    }
-
-    struct Headers<R: Read + Seek>(CommonState<R>);
-
-    impl<R: Read + Seek + 'static> State for Headers<R> {
-        fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-            let record = Record::read(&mut self.0.reader, self.0.endian, &self.0.warn)?;
-            match record {
-                Record::Variable(VariableRecord { width, .. }) => {
-                    self.0.var_types.push(VarType::from_width(width));
-                }
-                Record::EndOfHeaders(_) => {
-                    let next_state: Box<dyn State> = match self.0.compression {
-                        None => Box::new(Data(self.0)),
-                        Some(Compression::Simple) => Box::new(CompressedData::new(self.0)),
-                        Some(Compression::ZLib) => Box::new(ZlibHeader(self.0)),
-                    };
-                    return Ok(Some((record, next_state)));
-                }
-                _ => (),
-            };
-            Ok(Some((record, self)))
-        }
-    }
-
-    struct ZlibHeader<R: Read + Seek>(CommonState<R>);
-
-    impl<R: Read + Seek + 'static> State for ZlibHeader<R> {
-        fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-            let zheader = ZHeader::read(&mut self.0.reader, self.0.endian)?;
-            let next_state = Box::new(ZlibTrailer(self.0, zheader.clone()));
-            Ok(Some((Record::ZHeader(zheader), next_state)))
-        }
-    }
-
-    struct ZlibTrailer<R: Read + Seek>(CommonState<R>, ZHeader);
-
-    impl<R: Read + Seek + 'static> State for ZlibTrailer<R> {
-        fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-            let retval = ZTrailer::read(
-                &mut self.0.reader,
-                self.0.endian,
-                self.1.ztrailer_offset,
-                self.1.ztrailer_len,
-            )?;
-            let next_state = Box::new(CompressedData::new(CommonState {
-                reader: ZlibDecodeMultiple::new(self.0.reader),
-                warn: self.0.warn,
-                endian: self.0.endian,
-                bias: self.0.bias,
-                compression: self.0.compression,
-                var_types: self.0.var_types,
-            }));
-            match retval {
-                None => next_state.read(),
-                Some(ztrailer) => Ok(Some((Record::ZTrailer(ztrailer), next_state))),
-            }
-        }
-    }
-
-    struct Data<R: Read + Seek>(CommonState<R>);
-
-    impl<R: Read + Seek + 'static> State for Data<R> {
-        fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-            match Value::read_case(&mut self.0.reader, &self.0.var_types, self.0.endian)? {
-                None => Ok(None),
-                Some(values) => Ok(Some((Record::Case(values), self))),
-            }
-        }
-    }
-
-    struct CompressedData<R: Read + Seek> {
-        common: CommonState<R>,
-        codes: VecDeque<u8>,
-    }
-
-    impl<R: Read + Seek + 'static> CompressedData<R> {
-        fn new(common: CommonState<R>) -> CompressedData<R> {
-            CompressedData {
-                common,
-                codes: VecDeque::new(),
-            }
-        }
-    }
-
-    impl<R: Read + Seek + 'static> State for CompressedData<R> {
-        fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-            match Value::read_compressed_case(
-                &mut self.common.reader,
-                &self.common.var_types,
-                &mut self.codes,
-                self.common.endian,
-                self.common.bias,
-            )? {
-                None => Ok(None),
-                Some(values) => Ok(Some((Record::Case(values), self))),
-            }
-        }
-    }
-}
-
 #[derive(Copy, Clone)]
 pub enum Value {
     Number(Option<f64>),
@@ -724,44 +576,194 @@ where
     }
 }
 
-pub struct Reader {
-    state: Option<Box<dyn State>>,
+enum ReaderState {
+    Start,
+    Headers,
+    ZlibHeader,
+    ZlibTrailer {
+        ztrailer_offset: u64,
+        ztrailer_len: u64,
+    },
+    Cases,
+    End,
+}
+
+pub struct Reader<R>
+where
+    R: Read + Seek + 'static,
+{
+    reader: Option<R>,
+    warn: Box<dyn Fn(Error)>,
+
+    header: HeaderRecord,
+    var_types: Vec<VarType>,
+
+    state: ReaderState,
 }
 
-impl Reader {
-    pub fn new<R: Read + Seek + 'static, F: Fn(Error) + 'static>(reader: R, warn: F) -> Result<Self, Error> {
-        Ok(Reader {
-            state: Some(state::new(reader, warn)),
+impl<R> Reader<R>
+where
+    R: Read + Seek + 'static,
+{
+    pub fn new<F>(mut reader: R, warn: F) -> Result<Self, Error>
+    where
+        F: Fn(Error) + 'static,
+    {
+        let header = HeaderRecord::read(&mut reader)?;
+        Ok(Self {
+            reader: Some(reader),
+            warn: Box::new(warn),
+            header,
+            var_types: Vec::new(),
+            state: ReaderState::Start,
         })
     }
-    pub fn collect_headers(&mut self) -> Result<Vec<Record>, Error> {
-        let mut headers = Vec::new();
-        for record in self {
-            match record? {
-                Record::EndOfHeaders(_) => break,
-                r => headers.push(r),
-            };
-        }
-        Ok(headers)
+    fn cases(&mut self) -> Cases {
+        self.state = ReaderState::End;
+        Cases::new(
+            self.reader.take().unwrap(),
+            take(&mut self.var_types),
+            &self.header,
+        )
     }
 }
 
-impl Iterator for Reader {
+impl<R> Iterator for Reader<R>
+where
+    R: Read + Seek + 'static,
+{
     type Item = Result<Record, Error>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        match self.state.take()?.read() {
-            Ok(Some((record, next_state))) => {
-                self.state = Some(next_state);
+        match self.state {
+            ReaderState::Start => {
+                self.state = ReaderState::Headers;
+                Some(Ok(Record::Header(self.header.clone())))
+            }
+            ReaderState::Headers => {
+                let record = loop {
+                    match Record::read(
+                        self.reader.as_mut().unwrap(),
+                        self.header.endian,
+                        &self.warn,
+                    ) {
+                        Ok(Some(record)) => break record,
+                        Ok(None) => (),
+                        Err(error) => return Some(Err(error)),
+                    }
+                };
+                match record {
+                    Record::Variable(VariableRecord { width, .. }) => {
+                        self.var_types.push(VarType::from_width(width));
+                    }
+                    Record::EndOfHeaders(_) => {
+                        self.state = if let Some(Compression::ZLib) = self.header.compression {
+                            ReaderState::ZlibHeader
+                        } else {
+                            ReaderState::Cases
+                        };
+                    }
+                    _ => (),
+                };
                 Some(Ok(record))
             }
-            Ok(None) => None,
-            Err(error) => Some(Err(error)),
+            ReaderState::ZlibHeader => {
+                let zheader = match ZHeader::read(self.reader.as_mut().unwrap(), self.header.endian)
+                {
+                    Ok(zheader) => zheader,
+                    Err(error) => return Some(Err(error)),
+                };
+                self.state = ReaderState::ZlibTrailer {
+                    ztrailer_offset: zheader.ztrailer_offset,
+                    ztrailer_len: zheader.ztrailer_len,
+                };
+                Some(Ok(Record::ZHeader(zheader)))
+            }
+            ReaderState::ZlibTrailer {
+                ztrailer_offset,
+                ztrailer_len,
+            } => {
+                match ZTrailer::read(
+                    self.reader.as_mut().unwrap(),
+                    self.header.endian,
+                    ztrailer_offset,
+                    ztrailer_len,
+                ) {
+                    Ok(None) => Some(Ok(Record::Cases(Rc::new(RefCell::new(self.cases()))))),
+                    Ok(Some(ztrailer)) => Some(Ok(Record::ZTrailer(ztrailer))),
+                    Err(error) => Some(Err(error)),
+                }
+            }
+            ReaderState::Cases => Some(Ok(Record::Cases(Rc::new(RefCell::new(self.cases()))))),
+            ReaderState::End => None,
+        }
+    }
+}
+
+trait ReadSeek: Read + Seek {}
+impl<T> ReadSeek for T where T: Read + Seek {}
+
+pub struct Cases {
+    reader: Box<dyn ReadSeek>,
+    var_types: Vec<VarType>,
+    compression: Option<Compression>,
+    bias: f64,
+    endian: Endian,
+    codes: VecDeque<u8>,
+    eof: bool
+}
+
+impl Debug for Cases {
+    fn fmt(&self, f: &mut Formatter) -> FmtResult {
+        write!(f, "Cases")
+    }
+}
+
+impl Cases {
+    fn new<R>(reader: R, var_types: Vec<VarType>, header: &HeaderRecord) -> Self
+    where
+        R: Read + Seek + 'static,
+    {
+        Self {
+            reader: if header.compression == Some(Compression::ZLib) {
+                Box::new(ZlibDecodeMultiple::new(reader))
+            } else {
+                Box::new(reader)
+            },
+            var_types,
+            compression: header.compression,
+            bias: header.bias,
+            endian: header.endian,
+            codes: VecDeque::with_capacity(8),
+            eof: false,
         }
     }
 }
 
-impl FusedIterator for Reader {}
+impl Iterator for Cases {
+    type Item = Result<Vec<Value>, Error>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.eof {
+            return None;
+        }
+
+        let retval = if self.compression.is_some() {
+            Value::read_compressed_case(
+                &mut self.reader,
+                &self.var_types,
+                &mut self.codes,
+                self.endian,
+                self.bias,
+            )
+            .transpose()
+        } else {
+            Value::read_case(&mut self.reader, &self.var_types, self.endian).transpose()
+        };
+        self.eof = matches!(retval, None | Some(Err(_)));
+        retval
+    }
+}
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash)]
 pub struct Spec(pub u32);
@@ -1589,7 +1591,11 @@ impl Extension {
         Ok(())
     }
 
-    fn read<R: Read + Seek>(r: &mut R, endian: Endian, warn: &Box<dyn Fn(Error)>) -> Result<Option<Record>, Error> {
+    fn read<R: Read + Seek>(
+        r: &mut R,
+        endian: Endian,
+        warn: &Box<dyn Fn(Error)>,
+    ) -> Result<Option<Record>, Error> {
         let subtype = endian.parse(read_bytes(r)?);
         let header_offset = r.stream_position()?;
         let size: u32 = endian.parse(read_bytes(r)?);
@@ -1637,7 +1643,7 @@ impl Extension {
             Err(error) => {
                 warn(error);
                 Ok(None)
-            },
+            }
         }
     }
 }