cleanup
authorBen Pfaff <blp@cs.stanford.edu>
Wed, 22 Nov 2023 00:10:24 +0000 (16:10 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Wed, 22 Nov 2023 00:10:24 +0000 (16:10 -0800)
rust/src/main.rs
rust/src/raw.rs

index 2251b760fe7dab58b5ae78d60b9cc7258853b906..404e96d57d07becaf09c4688319a0ac3739ceee4 100644 (file)
@@ -18,7 +18,7 @@ use anyhow::Result;
 use clap::{Parser, ValueEnum};
 use encoding_rs::Encoding;
 use pspp::cooked::decode;
-use pspp::raw::{Reader, Record};
+use pspp::raw::{Reader, Record, Magic};
 use std::fs::File;
 use std::io::BufReader;
 use std::path::{Path, PathBuf};
@@ -59,6 +59,7 @@ fn parse_encoding(arg: &str) -> Result<&'static Encoding, UnknownEncodingError>
 
 #[derive(Clone, Copy, Debug, Default, ValueEnum)]
 enum Mode {
+    Identify,
     Raw,
     #[default]
     Cooked,
@@ -83,14 +84,24 @@ fn dissect(file_name: &Path, max_cases: u64, mode: Mode, encoding: Option<&'stat
     let reader = BufReader::new(reader);
     let mut reader = Reader::new(reader)?;
 
-    let headers: Vec<Record> = reader.collect_headers()?;
     match mode {
+        Mode::Identify => {
+            let Record::Header(header) = reader.next().unwrap()? else { unreachable!() };
+            match header.magic {
+                Magic::Sav => println!("SPSS System File"),
+                Magic::Zsav => println!("SPSS System File with Zlib compression"),
+                Magic::Ebcdic => println!("EBCDIC-encoded SPSS System File"),
+            }
+            return Ok(())
+        }
         Mode::Raw => {
+            let headers: Vec<Record> = reader.collect_headers()?;
             for header in headers {
                 println!("{header:?}");
             }
         }
         Mode::Cooked => {
+            let headers: Vec<Record> = reader.collect_headers()?;
             let headers = decode(headers, encoding, &|e| panic!("{e}"))?;
             for header in headers {
                 println!("{header:?}");
@@ -99,10 +110,10 @@ fn dissect(file_name: &Path, max_cases: u64, mode: Mode, encoding: Option<&'stat
     }
 
     for _ in 0..max_cases {
-        let Some(Ok(Record::Case(data))) = reader.next() else {
+        let Some(Ok(record)) = reader.next() else {
             break;
         };
-        println!("{:?}", data);
+        println!("{:?}", record);
     }
     Ok(())
 }
index 8722febe887afa1ab73f7e34ab3f526a4b831ac2..ed246717cef3f94b45c4330cf8ac8283cec4eba2 100644 (file)
@@ -283,8 +283,8 @@ impl HeaderRecord {
 
         let compression_code: u32 = endian.parse(read_bytes(r)?);
         let compression = match (magic, compression_code) {
-            (Magic::ZSAV, 2) => Some(Compression::ZLib),
-            (Magic::ZSAV, code) => return Err(Error::InvalidZsavCompression(code)),
+            (Magic::Zsav, 2) => Some(Compression::ZLib),
+            (Magic::Zsav, code) => return Err(Error::InvalidZsavCompression(code)),
             (_, 0) => None,
             (_, 1) => Some(Compression::Simple),
             (_, code) => return Err(Error::InvalidSavCompression(code)),
@@ -328,27 +328,35 @@ impl Header for HeaderRecord {
 }
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash)]
-pub struct Magic([u8; 4]);
+pub enum Magic {
+    /// Regular system file.
+    Sav,
+
+    /// System file with Zlib-compressed data.
+    Zsav,
+
+    /// EBCDIC-encoded system file.
+    Ebcdic,
+}
 
 impl Magic {
     /// Magic number for a regular system file.
-    pub const SAV: Magic = Magic(*b"$FL2");
+    pub const SAV: [u8; 4] = *b"$FL2";
 
     /// Magic number for a system file that contains zlib-compressed data.
-    pub const ZSAV: Magic = Magic(*b"$FL3");
+    pub const ZSAV: [u8; 4] = *b"$FL3";
 
-    /// Magic number for an EBDIC-encoded system file.  This is `$FL2` encoded
+    /// Magic number for an EBCDIC-encoded system file.  This is `$FL2` encoded
     /// in EBCDIC.
-    pub const EBCDIC: Magic = Magic([0x5b, 0xc6, 0xd3, 0xf2]);
+    pub const EBCDIC: [u8; 4] = [0x5b, 0xc6, 0xd3, 0xf2];
 }
 
 impl Debug for Magic {
     fn fmt(&self, f: &mut Formatter) -> FmtResult {
         let s = match *self {
-            Magic::SAV => "$FL2",
-            Magic::ZSAV => "$FL3",
-            Magic::EBCDIC => "($FL2 in EBCDIC)",
-            _ => return write!(f, "{:?}", self.0),
+            Magic::Sav => "$FL2",
+            Magic::Zsav => "$FL3",
+            Magic::Ebcdic => "($FL2 in EBCDIC)",
         };
         write!(f, "{s}")
     }
@@ -358,9 +366,10 @@ impl TryFrom<[u8; 4]> for Magic {
     type Error = Error;
 
     fn try_from(value: [u8; 4]) -> Result<Self, Self::Error> {
-        let magic = Magic(value);
-        match magic {
-            Magic::SAV | Magic::ZSAV | Magic::EBCDIC => Ok(magic),
+        match value {
+            Magic::SAV => Ok(Magic::Sav),
+            Magic::ZSAV => Ok(Magic::Zsav),
+            Magic::EBCDIC => Ok(Magic::Ebcdic),
             _ => Err(Error::BadMagic(value)),
         }
     }
@@ -455,7 +464,8 @@ mod state {
     impl<R: Read + Seek + 'static> State for ZlibHeader<R> {
         fn read(mut self: Box<Self>) -> Result<Option<(Record, Box<dyn State>)>, Error> {
             let zheader = ZHeader::read(&mut self.0.reader, self.0.endian)?;
-            Ok(Some((Record::ZHeader(zheader), self)))
+            let next_state = Box::new(ZlibTrailer(self.0, zheader.clone()));
+            Ok(Some((Record::ZHeader(zheader), next_state)))
         }
     }
 
@@ -1253,10 +1263,10 @@ pub enum MultipleResponseType {
 
 impl MultipleResponseType {
     fn parse(input: &[u8]) -> Result<(MultipleResponseType, &[u8]), Error> {
-        let (mr_type, input) = match input.first() {
-            Some(b'C') => (MultipleResponseType::MultipleCategory, &input[1..]),
-            Some(b'D') => {
-                let (value, input) = parse_counted_string(&input[1..])?;
+        let (mr_type, input) = match input.split_first() {
+            Some((b'C', input)) => (MultipleResponseType::MultipleCategory, input),
+            Some((b'D', input)) => {
+                let (value, input) = parse_counted_string(input)?;
                 (
                     MultipleResponseType::MultipleDichotomy {
                         value,
@@ -1265,11 +1275,7 @@ impl MultipleResponseType {
                     input,
                 )
             }
-            Some(b'E') => {
-                let Some(b' ') = input.get(1) else {
-                    return Err(Error::TBD);
-                };
-                let input = &input[2..];
+            Some((b'E', input)) => {
                 let (labels, input) = if let Some(rest) = input.strip_prefix(b" 1 ") {
                     (CategoryLabels::CountedValues, rest)
                 } else if let Some(rest) = input.strip_prefix(b" 11 ") {
@@ -1304,23 +1310,25 @@ impl MultipleResponseSet {
         };
         let (name, input) = input.split_at(equals);
         let (mr_type, input) = MultipleResponseType::parse(input)?;
-        let Some(b' ') = input.first() else {
+        let Some(input) = input.strip_prefix(b" ") else {
             return Err(Error::TBD);
         };
-        let (label, mut input) = parse_counted_string(&input[1..])?;
+        let (label, mut input) = parse_counted_string(input)?;
         let mut vars = Vec::new();
-        while input.first() == Some(&b' ') {
-            input = &input[1..];
-            let Some(length) = input.iter().position(|b| b" \n".contains(b)) else {
-                return Err(Error::TBD);
-            };
-            if length > 0 {
-                vars.push(input[..length].into());
+        while input.first() != Some(&b'\n') {
+            match input.split_first() {
+                Some((b' ', rest)) => {
+                    let Some(length) = rest.iter().position(|b| b" \n".contains(b)) else {
+                        return Err(Error::TBD);
+                    };
+                    let (var, rest) = rest.split_at(length);
+                    if !var.is_empty() {
+                        vars.push(var.into());
+                    }
+                    input = rest;
+                }
+                _ => return Err(Error::TBD),
             }
-            input = &input[length..];
-        }
-        if input.first() != Some(&b'\n') {
-            return Err(Error::TBD);
         }
         while input.first() == Some(&b'\n') {
             input = &input[1..];
@@ -1593,27 +1601,17 @@ impl Extension {
             data,
         };
         match subtype {
-            IntegerInfoRecord::SUBTYPE => Ok(IntegerInfoRecord::parse(
-                &extension, endian,
-            )?),
-            FloatInfoRecord::SUBTYPE => Ok(FloatInfoRecord::parse(
-                &extension, endian,
-            )?),
-            VarDisplayRecord::SUBTYPE => Ok(VarDisplayRecord::parse(
-                &extension, endian,
-            )?),
-            MultipleResponseRecord::SUBTYPE | 19 => Ok(
-                MultipleResponseRecord::parse(&extension, endian)?,
-            ),
-            LongStringValueLabelRecord::SUBTYPE => Ok(
-                LongStringValueLabelRecord::parse(&extension, endian)?,
-            ),
-            EncodingRecord::SUBTYPE => {
-                Ok(EncodingRecord::parse(&extension, endian)?)
+            IntegerInfoRecord::SUBTYPE => Ok(IntegerInfoRecord::parse(&extension, endian)?),
+            FloatInfoRecord::SUBTYPE => Ok(FloatInfoRecord::parse(&extension, endian)?),
+            VarDisplayRecord::SUBTYPE => Ok(VarDisplayRecord::parse(&extension, endian)?),
+            MultipleResponseRecord::SUBTYPE | 19 => {
+                Ok(MultipleResponseRecord::parse(&extension, endian)?)
+            }
+            LongStringValueLabelRecord::SUBTYPE => {
+                Ok(LongStringValueLabelRecord::parse(&extension, endian)?)
             }
-            NumberOfCasesRecord::SUBTYPE => Ok(NumberOfCasesRecord::parse(
-                &extension, endian,
-            )?),
+            EncodingRecord::SUBTYPE => Ok(EncodingRecord::parse(&extension, endian)?),
+            NumberOfCasesRecord::SUBTYPE => Ok(NumberOfCasesRecord::parse(&extension, endian)?),
             5 => Ok(Record::VariableSets(extension.into())),
             10 => Ok(Record::ProductInfo(extension.into())),
             13 => Ok(Record::LongNames(extension.into())),