put common state into struct
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 29 Jul 2023 04:28:36 +0000 (21:28 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 29 Jul 2023 04:28:36 +0000 (21:28 -0700)
rust/src/lib.rs

index 49887a35659de18830db8dde19d1437605c34e6f..504d5eb3a645af70428fe144f5bb24567239a13a 100644 (file)
@@ -174,7 +174,7 @@ pub struct Header {
     pub file_label: [u8; 64],
 
     /// Endianness of the data in the file header.
-    pub endianness: Endian,
+    pub endian: Endian,
 }
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash)]
@@ -227,59 +227,62 @@ struct Start<R: Read + Seek> {
     reader: R,
 }
 
+struct CommonState<R: Read + Seek> {
+    reader: R,
+    endian: Endian,
+    bias: f64,
+    compression: Option<Compression>,
+    var_types: Vec<VarType>,
+}
+
 impl<R: Read + Seek + 'static> State for Start<R> {
     fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
         let header = read_header(&mut self.reader)?;
-        Ok(Some((Record::Header(header), self)))
+        let next_state = Headers(CommonState {
+            reader: self.reader,
+            endian: header.endian,
+            bias: header.bias,
+            compression: header.compression,
+            var_types: Vec::new(),
+        });
+        Ok(Some((Record::Header(header), Box::new(next_state))))
     }
 }
 
-struct Headers<R: Read + Seek> {
-    reader: R,
-    endian: Endian,
-    compression: Option<Compression>,
-    var_types: Vec<VarType>,
-}
+struct Headers<R: Read + Seek>(CommonState<R>);
 
 impl<R: Read + Seek + 'static> State for Headers<R> {
     fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-        let rec_type: u32 = self.endian.parse(read_bytes(&mut self.reader)?);
+        let endian = self.0.endian;
+        let rec_type: u32 = endian.parse(read_bytes(&mut self.0.reader)?);
         let record = match rec_type {
             2 => {
-                let variable = read_variable_record(&mut self.reader, self.endian)?;
-                self.var_types.push(VarType::from_width(variable.width));
+                let variable = read_variable_record(&mut self.0.reader, endian)?;
+                self.0.var_types.push(VarType::from_width(variable.width));
                 Record::Variable(variable)
             }
-            3 => Record::ValueLabel(read_value_label_record(&mut self.reader, self.endian)?),
-            4 => Record::VarIndexes(read_var_indexes_record(&mut self.reader, self.endian)?),
-            6 => Record::Document(read_document_record(&mut self.reader, self.endian)?),
-            7 => Record::Extension(read_extension_record(&mut self.reader, self.endian)?),
+            3 => Record::ValueLabel(read_value_label_record(&mut self.0.reader, endian)?),
+            4 => Record::VarIndexes(read_var_indexes_record(&mut self.0.reader, endian)?),
+            6 => Record::Document(read_document_record(&mut self.0.reader, endian)?),
+            7 => Record::Extension(read_extension_record(&mut self.0.reader, endian)?),
             999 => {
-                let _: [u8; 4] = read_bytes(&mut self.reader)?;
-                let next_state: Box<dyn State> = match self.compression {
-                    None => Box::new(Data {
-                        reader: self.reader,
-                        endian: self.endian,
-                        var_types: self.var_types,
-                    }),
-                    Some(Compression::Simple) => Box::new(CompressedData {
-                        reader: self.reader,
-                        endian: self.endian,
-                        var_types: self.var_types,
-                        codes: VecDeque::new(),
-                    }),
-                    Some(Compression::ZLib) => Box::new(CompressedData {
-                        reader: ZlibDecodeMultiple::new(self.reader),
-                        endian: self.endian,
-                        var_types: self.var_types,
-                        codes: VecDeque::new(),
-                    }),
+                let _: [u8; 4] = read_bytes(&mut self.0.reader)?;
+                let next_state: Box<dyn State> = match self.0.compression {
+                    None => Box::new(Data(self.0)),
+                    Some(Compression::Simple) => Box::new(CompressedData::new(self.0)),
+                    Some(Compression::ZLib) => Box::new(CompressedData::new(CommonState {
+                        reader: ZlibDecodeMultiple::new(self.0.reader),
+                        endian: self.0.endian,
+                        bias: self.0.bias,
+                        compression: self.0.compression,
+                        var_types: self.0.var_types
+                    })),
                 };
                 return Ok(Some((Record::EndOfHeaders, next_state)));
             }
             _ => {
                 return Err(Error::BadRecordType {
-                    offset: self.reader.stream_position()?,
+                    offset: self.0.reader.stream_position()?,
                     rec_type,
                 })
             }
@@ -288,55 +291,55 @@ impl<R: Read + Seek + 'static> State for Headers<R> {
     }
 }
 
-struct Data<R: Read + Seek> {
-    reader: R,
-    endian: Endian,
-    var_types: Vec<VarType>,
-}
+struct Data<R: Read + Seek>(CommonState<R>);
 
 impl<R: Read + Seek + 'static> State for Data<R> {
     fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-        let case_start = self.reader.stream_position()?;
-        let mut values = Vec::with_capacity(self.var_types.len());
-        for (i, &var_type) in self.var_types.iter().enumerate() {
-            let Some(raw) = try_read_bytes(&mut self.reader)? else {
+        let case_start = self.0.reader.stream_position()?;
+        let mut values = Vec::with_capacity(self.0.var_types.len());
+        for (i, &var_type) in self.0.var_types.iter().enumerate() {
+            let Some(raw) = try_read_bytes(&mut self.0.reader)? else {
                 if i == 0 {
                     return Ok(None);
                 } else {
-                    let offset = self.reader.stream_position()?;
+                    let offset = self.0.reader.stream_position()?;
                     return Err(Error::EofInCase {
                         offset,
                         case_ofs: offset - case_start,
-                        case_len: self.var_types.len() * 8,
+                        case_len: self.0.var_types.len() * 8,
                     });
                 }
             };
-            values.push(Value::from_raw(var_type, raw, self.endian));
+            values.push(Value::from_raw(var_type, raw, self.0.endian));
         }
         Ok(Some((Record::Case(values), self)))
     }
 }
 
 struct CompressedData<R: Read + Seek> {
-    reader: R,
-    endian: Endian,
-    var_types: Vec<VarType>,
+    common: CommonState<R>,
     codes: VecDeque<u8>,
 }
 
+impl<R: Read + Seek + 'static> CompressedData<R> {
+    fn new(common: CommonState<R>) -> CompressedData<R> {
+        CompressedData { common, codes: VecDeque::new() }
+    }
+}
+
 impl<R: Read + Seek + 'static> State for CompressedData<R> {
     fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-        let case_start = self.reader.stream_position()?;
-        let mut values = Vec::with_capacity(self.var_types.len());
-        let bias = 100.0; // XXX
-        for (i, &var_type) in self.var_types.iter().enumerate() {
+        let case_start = self.common.reader.stream_position()?;
+        let mut values = Vec::with_capacity(self.common.var_types.len());
+        for (i, &var_type) in self.common.var_types.iter().enumerate() {
             let value = loop {
                 let Some(code) = self.codes.pop_front() else {
-                    let Some(new_codes): Option<[u8; 8]> = try_read_bytes(&mut self.reader)? else {
+                    let Some(new_codes): Option<[u8; 8]> = try_read_bytes(&mut self.common.reader)?
+                    else {
                         if i == 0 {
                             return Ok(None);
                         } else {
-                            let offset = self.reader.stream_position()?;
+                            let offset = self.common.reader.stream_position()?;
                             return Err(Error::EofInCompressedCase {
                                 offset,
                                 case_ofs: offset - case_start,
@@ -349,16 +352,16 @@ impl<R: Read + Seek + 'static> State for CompressedData<R> {
                 match code {
                     0 => (),
                     1..=251 => match var_type {
-                        VarType::Number => break Value::Number(Some(code as f64 - bias)),
+                        VarType::Number => break Value::Number(Some(code as f64 - self.common.bias)),
                         VarType::String => {
-                            break Value::String(self.endian.to_bytes(code as f64 - bias))
+                            break Value::String(self.common.endian.to_bytes(code as f64 - self.common.bias))
                         }
                     },
                     252 => {
                         if i == 0 {
                             return Ok(None);
                         } else {
-                            let offset = self.reader.stream_position()?;
+                            let offset = self.common.reader.stream_position()?;
                             return Err(Error::PartialCompressedCase {
                                 offset,
                                 case_ofs: offset - case_start,
@@ -366,14 +369,18 @@ impl<R: Read + Seek + 'static> State for CompressedData<R> {
                         }
                     }
                     253 => {
-                        break Value::from_raw(var_type, read_bytes(&mut self.reader)?, self.endian)
+                        break Value::from_raw(
+                            var_type,
+                            read_bytes(&mut self.common.reader)?,
+                            self.common.endian,
+                        )
                     }
                     254 => match var_type {
                         VarType::String => break Value::String(*b"        "), // XXX EBCDIC
                         VarType::Number => {
                             return Err(Error::CompressedStringExpected {
                                 offset: case_start,
-                                case_ofs: self.reader.stream_position()? - case_start,
+                                case_ofs: self.common.reader.stream_position()? - case_start,
                             })
                         }
                     },
@@ -382,7 +389,7 @@ impl<R: Read + Seek + 'static> State for CompressedData<R> {
                         VarType::String => {
                             return Err(Error::CompressedNumberExpected {
                                 offset: case_start,
-                                case_ofs: self.reader.stream_position()? - case_start,
+                                case_ofs: self.common.reader.stream_position()? - case_start,
                             })
                         }
                     },
@@ -491,16 +498,16 @@ fn read_header<R: Read>(r: &mut R) -> Result<Header, Error> {
 
     let eye_catcher: [u8; 60] = read_bytes(r)?;
     let layout_code: [u8; 4] = read_bytes(r)?;
-    let endianness = Endian::identify_u32(2, layout_code)
+    let endian = Endian::identify_u32(2, layout_code)
         .or_else(|| Endian::identify_u32(2, layout_code))
         .ok_or_else(|| Error::NotASystemFile)?;
-    let layout_code = endianness.parse(layout_code);
+    let layout_code = endian.parse(layout_code);
 
-    let nominal_case_size: u32 = endianness.parse(read_bytes(r)?);
+    let nominal_case_size: u32 = endian.parse(read_bytes(r)?);
     let nominal_case_size =
         (nominal_case_size <= i32::MAX as u32 / 16).then_some(nominal_case_size);
 
-    let compression_code: u32 = endianness.parse(read_bytes(r)?);
+    let compression_code: u32 = endian.parse(read_bytes(r)?);
     let compression = match (magic, compression_code) {
         (Magic::ZSAV, 2) => Some(Compression::ZLib),
         (Magic::ZSAV, code) => return Err(Error::InvalidZsavCompression(code)),
@@ -509,13 +516,13 @@ fn read_header<R: Read>(r: &mut R) -> Result<Header, Error> {
         (_, code) => return Err(Error::InvalidSavCompression(code)),
     };
 
-    let weight_index: u32 = endianness.parse(read_bytes(r)?);
+    let weight_index: u32 = endian.parse(read_bytes(r)?);
     let weight_index = (weight_index > 0).then_some(weight_index - 1);
 
-    let n_cases: u32 = endianness.parse(read_bytes(r)?);
+    let n_cases: u32 = endian.parse(read_bytes(r)?);
     let n_cases = (n_cases < i32::MAX as u32 / 2).then_some(n_cases);
 
-    let bias: f64 = endianness.parse(read_bytes(r)?);
+    let bias: f64 = endian.parse(read_bytes(r)?);
 
     let creation_date: [u8; 9] = read_bytes(r)?;
     let creation_time: [u8; 8] = read_bytes(r)?;
@@ -534,7 +541,7 @@ fn read_header<R: Read>(r: &mut R) -> Result<Header, Error> {
         creation_time,
         eye_catcher,
         file_label,
-        endianness,
+        endian,
     })
 }