actually use the new state machine
author Ben Pfaff <blp@cs.stanford.edu>
Sat, 29 Jul 2023 03:47:37 +0000 (20:47 -0700)
committer Ben Pfaff <blp@cs.stanford.edu>
Sat, 29 Jul 2023 03:47:37 +0000 (20:47 -0700)
rust/src/lib.rs

index 3932dcf5b24995767c8a82ace93508815a58fe76..49887a35659de18830db8dde19d1437605c34e6f 100644 (file)
@@ -1,11 +1,12 @@
 #![allow(unused_variables)]
 use endian::{Endian, Parse, ToBytes};
-use flate2::bufread::ZlibDecoder;
+use flate2::read::ZlibDecoder;
 use num::Integer;
 use num_derive::FromPrimitive;
 use std::{
     collections::VecDeque,
-    io::{BufReader, Error as IoError, Read, Seek, SeekFrom},
+    io::{Error as IoError, Read, Seek, SeekFrom},
+    iter::FusedIterator,
 };
 use thiserror::Error;
 
@@ -218,29 +219,23 @@ impl VarType {
     }
 }
 
-pub struct Reader<R: Read + Seek> {
-    r: BufReader<R>,
-    var_types: Vec<VarType>,
-    state: ReaderState,
-}
-
+/// A single step of the system-file reader's state machine.
+///
+/// Each state is consumed when it is read and hands back the state that should
+/// be read next.
 trait State {
+    /// Consumes this state and attempts to produce one record.
+    ///
+    /// Returns `Ok(Some((record, next)))` with the record and the state to
+    /// read from next, `Ok(None)` when there is nothing further to yield, or
+    /// `Err` when reading fails.
     fn read(self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error>;
 }
 
+/// The state a `Reader` begins in: the next thing to read is the file header.
 struct Start<R: Read + Seek> {
-    r: BufReader<R>,
+    /// Input to parse; `Reader::new` stores the caller's reader here as-is.
+    reader: R,
 }
 
 impl<R: Read + Seek + 'static> State for Start<R> {
     fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
-        let header = read_header(&mut self.r)?;
+        let header = read_header(&mut self.reader)?;
         Ok(Some((Record::Header(header), self)))
     }
 }
 
 struct Headers<R: Read + Seek> {
-    reader: BufReader<R>,
+    reader: R,
     endian: Endian,
     compression: Option<Compression>,
     var_types: Vec<VarType>,
@@ -294,7 +289,7 @@ impl<R: Read + Seek + 'static> State for Headers<R> {
 }
 
+/// Reader, byte order, and variable types bundled for the data portion of the
+/// file.
+///
+/// NOTE(review): presumably the state that yields uncompressed cases (its
+/// `State` impl is outside this hunk) -- confirm.
 struct Data<R: Read + Seek> {
-    reader: BufReader<R>,
+    /// Input that values are read from.
+    reader: R,
+    /// Byte order used when decoding values.
     endian: Endian,
+    /// Type of each variable.
+    // NOTE(review): presumably one entry per variable record, in file order -- confirm.
     var_types: Vec<VarType>,
 }
@@ -355,7 +350,9 @@ impl<R: Read + Seek + 'static> State for CompressedData<R> {
                     0 => (),
                     1..=251 => match var_type {
                         VarType::Number => break Value::Number(Some(code as f64 - bias)),
-                        VarType::String => break Value::String(self.endian.to_bytes(code as f64 - bias)),
+                        VarType::String => {
+                            break Value::String(self.endian.to_bytes(code as f64 - bias))
+                        }
                     },
                     252 => {
                         if i == 0 {
@@ -368,7 +365,9 @@ impl<R: Read + Seek + 'static> State for CompressedData<R> {
                             });
                         }
                     }
-                    253 => break Value::from_raw(var_type, read_bytes(&mut self.reader)?, self.endian),
+                    253 => {
+                        break Value::from_raw(var_type, read_bytes(&mut self.reader)?, self.endian)
+                    }
                     254 => match var_type {
                         VarType::String => break Value::String(*b"        "), // XXX EBCDIC
                         VarType::Number => {
@@ -399,14 +398,14 @@ struct ZlibDecodeMultiple<R>
 where
     R: Read + Seek,
 {
-    reader: Option<ZlibDecoder<BufReader<R>>>,
+    reader: Option<ZlibDecoder<R>>,
 }
 
 impl<R> ZlibDecodeMultiple<R>
 where
     R: Read + Seek,
 {
-    fn new(reader: BufReader<R>) -> ZlibDecodeMultiple<R> {
+    fn new(reader: R) -> ZlibDecodeMultiple<R> {
         ZlibDecodeMultiple {
             reader: Some(ZlibDecoder::new(reader)),
         }
@@ -439,34 +438,6 @@ where
     }
 }
 
-/*
-impl<R> BufRead for ZlibDecodeMultiple<R>
-where
-    R: Read + Seek,
-{
-    fn fill_buf(&mut self) -> Result<&[u8], IoError> {
-        self.reader.as_mut().unwrap().fill_buf()
-    }
-    fn consume(&mut self, amt: usize) {
-        self.reader.as_mut().unwrap().consume(amt)
-    }
-}*/
-
-enum ReaderState {
-    Start,
-    Headers(Endian, Option<Compression>),
-    Data(Endian),
-    CompressedData(Endian, VecDeque<u8>),
-    ZHeader(Endian),
-    ZTrailer {
-        endian: Endian,
-        ztrailer_ofs: u64,
-        ztrailer_len: u64,
-    },
-    //ZData,
-    End,
-}
-
 #[derive(Copy, Clone)]
 pub enum Value {
     Number(Option<f64>),
@@ -485,190 +456,35 @@ impl Value {
     }
 }
 
-impl<R: Read + Seek> Reader<R> {
-    pub fn new(r: R) -> Result<Reader<R>, Error> {
+/// Iterator over the records in a system file.
+///
+/// The concrete input type is erased behind `Box<dyn State>`, so `Reader`
+/// itself is not generic.
+pub struct Reader {
+    /// State to read the next record from, or `None` once iteration has ended
+    /// (after end of input or after an error has been reported).
+    state: Option<Box<dyn State>>,
+}
+
+impl Reader {
+    /// Creates a reader that will parse `reader` as a system file.
+    ///
+    /// No input is consumed here; the file header is read on the first call
+    /// to `next`.  `reader` is stored as-is, without being wrapped in a
+    /// `BufReader`.
+    ///
+    /// # Errors
+    ///
+    /// Currently always returns `Ok`.
+    pub fn new<R: Read + Seek + 'static>(reader: R) -> Result<Reader, Error> {
         Ok(Reader {
-            r: BufReader::new(r),
-            var_types: Vec::new(),
-            state: ReaderState::Start,
+            state: Some(Box::new(Start { reader })),
         })
     }
-    fn _next(&mut self) -> Result<Option<Record>, Error> {
-        match self.state {
-            ReaderState::Start => {
-                let header = read_header(&mut self.r)?;
-                self.state = ReaderState::Headers(header.endianness, header.compression);
-                Ok(Some(Record::Header(header)))
-            }
-            ReaderState::Headers(endian, compression) => {
-                let rec_type: u32 = endian.parse(read_bytes(&mut self.r)?);
-                let record = match rec_type {
-                    2 => {
-                        let variable = read_variable_record(&mut self.r, endian)?;
-                        self.var_types.push(VarType::from_width(variable.width));
-                        Record::Variable(variable)
-                    }
-                    3 => Record::ValueLabel(read_value_label_record(&mut self.r, endian)?),
-                    4 => Record::VarIndexes(read_var_indexes_record(&mut self.r, endian)?),
-                    6 => Record::Document(read_document_record(&mut self.r, endian)?),
-                    7 => Record::Extension(read_extension_record(&mut self.r, endian)?),
-                    999 => {
-                        let _: [u8; 4] = read_bytes(&mut self.r)?;
-                        self.state = match compression {
-                            None => ReaderState::Data(endian),
-                            Some(Compression::Simple) => {
-                                ReaderState::CompressedData(endian, VecDeque::new())
-                            }
-                            Some(Compression::ZLib) => ReaderState::ZHeader(endian),
-                        };
-                        return Ok(Some(Record::EndOfHeaders));
-                    }
-                    _ => {
-                        return Err(Error::BadRecordType {
-                            offset: self.r.stream_position()?,
-                            rec_type,
-                        })
-                    }
-                };
-                Ok(Some(record))
-            }
-            ReaderState::Data(endian) => {
-                let case_start = self.r.stream_position()?;
-                let mut values = Vec::with_capacity(self.var_types.len());
-                for (i, &var_type) in self.var_types.iter().enumerate() {
-                    let Some(raw) = try_read_bytes(&mut self.r)? else {
-                        if i == 0 {
-                            return Ok(None);
-                        } else {
-                            let offset = self.r.stream_position()?;
-                            return Err(Error::EofInCase {
-                                offset,
-                                case_ofs: offset - case_start,
-                                case_len: self.var_types.len() * 8,
-                            });
-                        }
-                    };
-                    values.push(Value::from_raw(var_type, raw, endian));
-                }
-                Ok(Some(Record::Case(values)))
-            }
-            ReaderState::CompressedData(endian, ref mut codes) => {
-                let case_start = self.r.stream_position()?;
-                let mut values = Vec::with_capacity(self.var_types.len());
-                let bias = 100.0; // XXX
-                for (i, &var_type) in self.var_types.iter().enumerate() {
-                    let value = loop {
-                        let Some(code) = codes.pop_front() else {
-                            let Some(new_codes): Option<[u8; 8]> = try_read_bytes(&mut self.r)?
-                            else {
-                                if i == 0 {
-                                    return Ok(None);
-                                } else {
-                                    let offset = self.r.stream_position()?;
-                                    return Err(Error::EofInCompressedCase {
-                                        offset,
-                                        case_ofs: offset - case_start,
-                                    });
-                                }
-                            };
-                            codes.extend(new_codes.into_iter());
-                            continue;
-                        };
-                        match code {
-                            0 => (),
-                            1..=251 => match var_type {
-                                VarType::Number => break Value::Number(Some(code as f64 - bias)),
-                                VarType::String => {
-                                    break Value::String(endian.to_bytes(code as f64 - bias))
-                                }
-                            },
-                            252 => {
-                                if i == 0 {
-                                    return Ok(None);
-                                } else {
-                                    let offset = self.r.stream_position()?;
-                                    return Err(Error::PartialCompressedCase {
-                                        offset,
-                                        case_ofs: offset - case_start,
-                                    });
-                                }
-                            }
-                            253 => {
-                                break Value::from_raw(var_type, read_bytes(&mut self.r)?, endian)
-                            }
-                            254 => match var_type {
-                                VarType::String => break Value::String(*b"        "), // XXX EBCDIC
-                                VarType::Number => {
-                                    return Err(Error::CompressedStringExpected {
-                                        offset: case_start,
-                                        case_ofs: self.r.stream_position()? - case_start,
-                                    })
-                                }
-                            },
-                            255 => match var_type {
-                                VarType::Number => break Value::Number(None),
-                                VarType::String => {
-                                    return Err(Error::CompressedNumberExpected {
-                                        offset: case_start,
-                                        case_ofs: self.r.stream_position()? - case_start,
-                                    })
-                                }
-                            },
-                        }
-                    };
-                    values.push(value);
-                }
-                Ok(Some(Record::Case(values)))
-            }
-            ReaderState::ZHeader(endian) => {
-                let zheader = read_zheader(&mut self.r, endian)?;
-                self.state = ReaderState::ZTrailer {
-                    endian,
-                    ztrailer_ofs: zheader.ztrailer_offset,
-                    ztrailer_len: zheader.ztrailer_len,
-                };
-                Ok(Some(Record::ZHeader(zheader)))
-            }
-            ReaderState::ZTrailer {
-                endian,
-                ztrailer_ofs,
-                ztrailer_len,
-            } => {
-                //self.state = ReaderState::ZData;
-                match read_ztrailer(&mut self.r, endian, ztrailer_ofs, ztrailer_len)? {
-                    Some(ztrailer) => Ok(Some(Record::ZTrailer(ztrailer))),
-                    None => self._next(),
-                }
-            }
-            /*
-                        ReaderState::ZData(zlib_decoder) => {
-                            let zlib_decoder = zlib_decoder.unwrap_or_else(
-                        },
-            */
-            ReaderState::End => Ok(None),
-        }
-    }
 }
 
-impl<R: Read + Seek> Iterator for Reader<R> {
+impl Iterator for Reader {
     type Item = Result<Record, Error>;
 
+    /// Reads the next record by consuming the current state.
+    ///
+    /// The state is taken out of `self` up front and only put back when a
+    /// record was produced, so after `Ok(None)` or an error every later call
+    /// returns `None`.
     fn next(&mut self) -> Option<Self::Item> {
-        let retval = self._next();
-        match retval {
-            Ok(None) => {
-                self.state = ReaderState::End;
-                None
-            }
-            Ok(Some(record)) => Some(Ok(record)),
-            Err(error) => {
-                self.state = ReaderState::End;
-                Some(Err(error))
+        match self.state.take()?.read() {
+            Ok(Some((record, next_state))) => {
+                self.state = Some(next_state);
+                Some(Ok(record))
             }
+            Ok(None) => None,
+            Err(error) => Some(Err(error)),
         }
     }
 }
 
+// `next` takes the state out of `self` and restores it only after yielding a
+// record, so once it has returned `None` it keeps returning `None`.
+impl FusedIterator for Reader {}
+
 fn read_header<R: Read>(r: &mut R) -> Result<Header, Error> {
     let magic: [u8; 4] = read_bytes(r)?;
     let magic: Magic = magic.try_into().map_err(|_| Error::NotASystemFile)?;
@@ -748,10 +564,7 @@ pub struct Variable {
     pub label: Option<Vec<u8>>,
 }
 
-fn read_variable_record<R: Read + Seek>(
-    r: &mut BufReader<R>,
-    endian: Endian,
-) -> Result<Variable, Error> {
+fn read_variable_record<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<Variable, Error> {
     let offset = r.stream_position()?;
     let width: i32 = endian.parse(read_bytes(r)?);
     let has_variable_label: u32 = endian.parse(read_bytes(r)?);
@@ -829,10 +642,7 @@ impl ValueLabel {
     pub const MAX: u32 = u32::MAX / 8;
 }
 
-fn read_value_label_record<R: Read + Seek>(
-    r: &mut BufReader<R>,
-    endian: Endian,
-) -> Result<ValueLabel, Error> {
+fn read_value_label_record<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<ValueLabel, Error> {
     let offset = r.stream_position()?;
     let n: u32 = endian.parse(read_bytes(r)?);
     if n > ValueLabel::MAX {
@@ -870,10 +680,7 @@ impl VarIndexes {
     pub const MAX: u32 = u32::MAX / 8;
 }
 
-fn read_var_indexes_record<R: Read + Seek>(
-    r: &mut BufReader<R>,
-    endian: Endian,
-) -> Result<VarIndexes, Error> {
+fn read_var_indexes_record<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<VarIndexes, Error> {
     let offset = r.stream_position()?;
     let n: u32 = endian.parse(read_bytes(r)?);
     if n > VarIndexes::MAX {
@@ -905,10 +712,7 @@ pub struct Document {
     pub lines: Vec<[u8; DOC_LINE_LEN as usize]>,
 }
 
-fn read_document_record<R: Read + Seek>(
-    r: &mut BufReader<R>,
-    endian: Endian,
-) -> Result<Document, Error> {
+fn read_document_record<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<Document, Error> {
     let offset = r.stream_position()?;
     let n: u32 = endian.parse(read_bytes(r)?);
     match n {
@@ -1012,10 +816,7 @@ fn extension_record_size_requirements(extension: ExtensionType) -> (u32, u32) {
     }
 }
 
-fn read_extension_record<R: Read + Seek>(
-    r: &mut BufReader<R>,
-    endian: Endian,
-) -> Result<Extension, Error> {
+fn read_extension_record<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<Extension, Error> {
     let subtype = endian.parse(read_bytes(r)?);
     let offset = r.stream_position()?;
     let size: u32 = endian.parse(read_bytes(r)?);
@@ -1053,7 +854,7 @@ pub struct ZHeader {
     pub ztrailer_len: u64,
 }
 
-fn read_zheader<R: Read + Seek>(r: &mut BufReader<R>, endian: Endian) -> Result<ZHeader, Error> {
+fn read_zheader<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<ZHeader, Error> {
     let offset = r.stream_position()?;
     let zheader_offset: u64 = endian.parse(read_bytes(r)?);
     let ztrailer_offset: u64 = endian.parse(read_bytes(r)?);
@@ -1102,7 +903,7 @@ pub struct ZBlock {
 }
 
 fn read_ztrailer<R: Read + Seek>(
-    r: &mut BufReader<R>,
+    r: &mut R,
     endian: Endian,
     ztrailer_ofs: u64,
     ztrailer_len: u64,
@@ -1166,7 +967,7 @@ fn read_bytes<const N: usize, R: Read>(r: &mut R) -> Result<[u8; N], IoError> {
     Ok(buf)
 }
 
-fn read_vec<R: Read>(r: &mut BufReader<R>, n: usize) -> Result<Vec<u8>, IoError> {
+fn read_vec<R: Read>(r: &mut R, n: usize) -> Result<Vec<u8>, IoError> {
     let mut vec = vec![0; n];
     r.read_exact(&mut vec)?;
     Ok(vec)