work
authorBen Pfaff <blp@cs.stanford.edu>
Mon, 7 Aug 2023 05:39:05 +0000 (22:39 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Mon, 7 Aug 2023 05:39:05 +0000 (22:39 -0700)
rust/src/main.rs
rust/src/raw.rs

index 0bb33bbbcc23db9d4a520403e37e883a2b0d0453..319d7aea7c0f32a020db403a21fe60fad96cac67 100644 (file)
  * You should have received a copy of the GNU General Public License
  * along with this program.  If not, see <http://www.gnu.org/licenses/>. */
 
-use anyhow::{Result};
+use anyhow::Result;
 use clap::Parser;
-use pspp::raw::Reader;
+use pspp::{
+    raw::{Reader, Record},
+};
 use std::fs::File;
 use std::io::BufReader;
 use std::path::{Path, PathBuf};
@@ -28,7 +30,7 @@ use std::str;
 struct Args {
     /// Maximum number of cases to print.
     #[arg(long = "data", default_value_t = 0)]
-    max_cases: usize,
+    max_cases: u64,
 
     /// Files to dissect.
     #[arg(required = true)]
@@ -36,20 +38,33 @@ struct Args {
 }
 
 fn main() -> Result<()> {
-    let Args { files, .. } = Args::parse();
+    let Args { max_cases, files } = Args::parse();
 
     for file in files {
-        dissect(&file)?;
+        dissect(&file, max_cases)?;
     }
     Ok(())
 }
 
-fn dissect(file_name: &Path) -> Result<()> {
+fn dissect(file_name: &Path, max_cases: u64) -> Result<()> {
     let reader = File::open(file_name)?;
     let reader = BufReader::new(reader);
-    let reader = Reader::new(reader)?;
-    for record in reader {
+    let mut reader = Reader::new(reader)?;
+    let records: Vec<Record> = reader.collect_headers()?;
+
+    let mut n_cases = 0;
+    for record in records {
         println!("{record:?}");
+        match record {
+            Record::EndOfHeaders(_) if max_cases == 0 => break,
+            Record::Case(_) => {
+                n_cases += 1;
+                if n_cases >= max_cases {
+                    break;
+                }
+            }
+            _ => (),
+        }
     }
     Ok(())
 }
index 4dbddec151f618fd93cc932d631c8ee5fd04f76c..bae59656f05474b0c24874fe2401b97b563b0626 100644 (file)
@@ -27,6 +27,20 @@ pub enum Record {
     Variable(Variable),
     ValueLabel(ValueLabel),
     VarIndexes(VarIndexes),
+    IntegerInfo(IntegerInfo),
+    FloatInfo(FloatInfo),
+    VariableSets(UnencodedString),
+    VarDisplay(VarDisplayRecord),
+    MultipleResponse(MultipleResponseRecord),
+    LongStringValueLabels(LongStringValueLabelRecord),
+    Encoding(EncodingRecord),
+    NumberOfCases(NumberOfCasesRecord),
+    ProductInfo(UnencodedString),
+    LongNames(UnencodedString),
+    LongStrings(UnencodedString),
+    FileAttributes(UnencodedString),
+    VariableAttributes(UnencodedString),
+    TextExtension(TextExtension),
     Extension(Extension),
     EndOfHeaders(u32),
     ZHeader(ZHeader),
@@ -42,7 +56,7 @@ impl Record {
             3 => Ok(Record::ValueLabel(ValueLabel::read(reader, endian)?)),
             4 => Ok(Record::VarIndexes(VarIndexes::read(reader, endian)?)),
             6 => Ok(Record::Document(Document::read(reader, endian)?)),
-            7 => Ok(Record::Extension(Extension::read(reader, endian)?)),
+            7 => Ok(Extension::read(reader, endian)?),
             999 => Ok(Record::EndOfHeaders(endian.parse(read_bytes(reader)?))),
             _ => Err(Error::BadRecordType {
                 offset: reader.stream_position()?,
@@ -584,6 +598,16 @@ impl Reader {
             state: Some(state::new(reader)),
         })
     }
+    pub fn collect_headers(&mut self) -> Result<Vec<Record>, Error> {
+        let mut headers = Vec::new();
+        for record in self {
+            match record? {
+                Record::EndOfHeaders(_) => break,
+                r => headers.push(r),
+            };
+        }
+        Ok(headers)
+    }
 }
 
 impl Iterator for Reader {
@@ -992,48 +1016,6 @@ impl Document {
     }
 }
 
-/*
-#[derive(FromPrimitive)]
-enum ExtensionType {
-    /// Machine integer info.
-    Integer = 3,
-    /// Machine floating-point info.
-    Float = 4,
-    /// Variable sets.
-    VarSets = 5,
-    /// DATE.
-    Date = 6,
-    /// Multiple response sets.
-    Mrsets = 7,
-    /// SPSS Data Entry.
-    DataEntry = 8,
-    /// Extra product info text.
-    ProductInfo = 10,
-    /// Variable display parameters.
-    Display = 11,
-    /// Long variable names.
-    LongNames = 13,
-    /// Long strings.
-    LongStrings = 14,
-    /// Extended number of cases.
-    Ncases = 16,
-    /// Data file attributes.
-    FileAttrs = 17,
-    /// Variable attributes.
-    VarAttrs = 18,
-    /// Multiple response sets (extended).
-    Mrsets2 = 19,
-    /// Character encoding.
-    Encoding = 20,
-    /// Value labels for long strings.
-    LongLabels = 21,
-    /// Missing values for long strings.
-    LongMissing = 22,
-    /// "Format properties in dataview table".
-    Dataview = 24,
-}
- */
-
 trait TextRecord
 where
     Self: Sized,
@@ -1046,12 +1028,14 @@ trait ExtensionRecord
 where
     Self: Sized,
 {
+    const SUBTYPE: u32;
     const SIZE: Option<u32>;
     const COUNT: Option<u32>;
     const NAME: &'static str;
     fn parse(ext: &Extension, endian: Endian, warn: impl Fn(Error)) -> Result<Self, Error>;
 }
 
+#[derive(Clone, Debug)]
 pub struct IntegerInfo {
     pub version: (i32, i32, i32),
     pub machine_code: i32,
@@ -1062,6 +1046,7 @@ pub struct IntegerInfo {
 }
 
 impl ExtensionRecord for IntegerInfo {
+    const SUBTYPE: u32 = 3;
     const SIZE: Option<u32> = Some(4);
     const COUNT: Option<u32> = Some(8);
     const NAME: &'static str = "integer record";
@@ -1084,6 +1069,7 @@ impl ExtensionRecord for IntegerInfo {
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct FloatInfo {
     pub sysmis: f64,
     pub highest: f64,
@@ -1091,6 +1077,7 @@ pub struct FloatInfo {
 }
 
 impl ExtensionRecord for FloatInfo {
+    const SUBTYPE: u32 = 4;
     const SIZE: Option<u32> = Some(8);
     const COUNT: Option<u32> = Some(3);
     const NAME: &'static str = "floating point record";
@@ -1110,10 +1097,12 @@ impl ExtensionRecord for FloatInfo {
     }
 }
 
+#[derive(Clone, Debug)]
 pub enum CategoryLabels {
     VarLabels,
     CountedValues,
 }
+#[derive(Clone, Debug)]
 pub enum MultipleResponseType {
     MultipleDichotomy {
         value: UnencodedString,
@@ -1121,6 +1110,7 @@ pub enum MultipleResponseType {
     },
     MultipleCategory,
 }
+#[derive(Clone, Debug)]
 pub struct MultipleResponseSet {
     pub name: UnencodedString,
     pub label: UnencodedString,
@@ -1202,9 +1192,11 @@ impl MultipleResponseSet {
     }
 }
 
-pub struct MultipleResponseSets(Vec<MultipleResponseSet>);
+#[derive(Clone, Debug)]
+pub struct MultipleResponseRecord(Vec<MultipleResponseSet>);
 
-impl ExtensionRecord for MultipleResponseSets {
+impl ExtensionRecord for MultipleResponseRecord {
+    const SUBTYPE: u32 = 7;
     const SIZE: Option<u32> = Some(1);
     const COUNT: Option<u32> = None;
     const NAME: &'static str = "multiple response set record";
@@ -1219,7 +1211,7 @@ impl ExtensionRecord for MultipleResponseSets {
             sets.push(set);
             input = rest;
         }
-        Ok(MultipleResponseSets(sets))
+        Ok(MultipleResponseRecord(sets))
     }
 }
 
@@ -1252,9 +1244,11 @@ impl TextRecord for ExtraProductInfo {
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct VarDisplayRecord(Vec<u32>);
 
 impl ExtensionRecord for VarDisplayRecord {
+    const SUBTYPE: u32 = 11;
     const SIZE: Option<u32> = Some(4);
     const COUNT: Option<u32> = None;
     const NAME: &'static str = "variable display record";
@@ -1366,6 +1360,7 @@ impl TextRecord for VeryLongStringRecord {
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct LongStringValueLabels {
     pub var_name: UnencodedString,
     pub width: u32,
@@ -1374,9 +1369,11 @@ pub struct LongStringValueLabels {
     pub labels: Vec<(UnencodedString, UnencodedString)>,
 }
 
-pub struct LongStringValueLabelSet(Vec<LongStringValueLabels>);
+#[derive(Clone, Debug)]
+pub struct LongStringValueLabelRecord(Vec<LongStringValueLabels>);
 
-impl ExtensionRecord for LongStringValueLabelSet {
+impl ExtensionRecord for LongStringValueLabelRecord {
+    const SUBTYPE: u32 = 21;
     const SIZE: Option<u32> = Some(1);
     const COUNT: Option<u32> = None;
     const NAME: &'static str = "long string value labels record";
@@ -1402,7 +1399,7 @@ impl ExtensionRecord for LongStringValueLabelSet {
                 labels,
             })
         }
-        Ok(LongStringValueLabelSet(label_set))
+        Ok(LongStringValueLabelRecord(label_set))
     }
 }
 
@@ -1417,6 +1414,7 @@ pub struct LongStringMissingValues {
 pub struct LongStringMissingValueSet(Vec<LongStringMissingValues>);
 
 impl ExtensionRecord for LongStringMissingValueSet {
+    const SUBTYPE: u32 = 22;
     const SIZE: Option<u32> = Some(1);
     const COUNT: Option<u32> = None;
     const NAME: &'static str = "long string missing values record";
@@ -1462,9 +1460,11 @@ impl ExtensionRecord for LongStringMissingValueSet {
     }
 }
 
-pub struct Encoding(pub String);
+#[derive(Clone, Debug)]
+pub struct EncodingRecord(pub String);
 
-impl ExtensionRecord for Encoding {
+impl ExtensionRecord for EncodingRecord {
+    const SUBTYPE: u32 = 20;
     const SIZE: Option<u32> = Some(1);
     const COUNT: Option<u32> = None;
     const NAME: &'static str = "encoding record";
@@ -1472,7 +1472,7 @@ impl ExtensionRecord for Encoding {
     fn parse(ext: &Extension, _endian: Endian, _warn: impl Fn(Error)) -> Result<Self, Error> {
         ext.check_size::<Self>()?;
 
-        Ok(Encoding(String::from_utf8(ext.data.clone()).map_err(
+        Ok(EncodingRecord(String::from_utf8(ext.data.clone()).map_err(
             |_| Error::BadEncodingName { offset: ext.offset },
         )?))
     }
@@ -1599,6 +1599,7 @@ impl TextRecord for VariableAttributeRecord {
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct NumberOfCasesRecord {
     /// Always observed as 1.
     pub one: u64,
@@ -1608,6 +1609,7 @@ pub struct NumberOfCasesRecord {
 }
 
 impl ExtensionRecord for NumberOfCasesRecord {
+    const SUBTYPE: u32 = 16;
     const SIZE: Option<u32> = Some(8);
     const COUNT: Option<u32> = Some(2);
     const NAME: &'static str = "extended number of cases record";
@@ -1623,6 +1625,22 @@ impl ExtensionRecord for NumberOfCasesRecord {
     }
 }
 
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+pub enum TextExtensionSubtype {
+    VariableSets = 5,
+    ProductInfo = 10,
+    LongNames = 13,
+    LongStrings = 14,
+    FileAttributes = 17,
+    VariableAttributes = 18,
+}
+
+#[derive(Clone, Debug)]
+pub struct TextExtension {
+    pub subtype: TextExtensionSubtype,
+    pub string: UnencodedString,
+}
+
 #[derive(Clone, Debug)]
 pub struct Extension {
     /// Offset from the start of the file to the start of the record.
@@ -1694,7 +1712,7 @@ impl Extension {
         Ok(())
     }
 
-    fn read<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<Extension, Error> {
+    fn read<R: Read + Seek>(r: &mut R, endian: Endian) -> Result<Record, Error> {
         let subtype = endian.parse(read_bytes(r)?);
         let offset = r.stream_position()?;
         let size: u32 = endian.parse(read_bytes(r)?);
@@ -1709,13 +1727,29 @@ impl Extension {
         };
         let offset = r.stream_position()?;
         let data = read_vec(r, product as usize)?;
-        Ok(Extension {
+        let extension = Extension {
             offset,
             subtype,
             size,
             count,
             data,
-        })
+        };
+        match subtype {
+            IntegerInfo::SUBTYPE => Ok(Record::IntegerInfo(IntegerInfo::parse(&extension, endian, |_| ())?)),
+            FloatInfo::SUBTYPE => Ok(Record::FloatInfo(FloatInfo::parse(&extension, endian, |_| ())?)),
+            VarDisplayRecord::SUBTYPE => Ok(Record::VarDisplay(VarDisplayRecord::parse(&extension, endian, |_| ())?)),
+            MultipleResponseRecord::SUBTYPE | 19 => Ok(Record::MultipleResponse(MultipleResponseRecord::parse(&extension, endian, |_| ())?)),
+            LongStringValueLabelRecord::SUBTYPE => Ok(Record::LongStringValueLabels(LongStringValueLabelRecord::parse(&extension, endian, |_| ())?)),
+            EncodingRecord::SUBTYPE => Ok(Record::Encoding(EncodingRecord::parse(&extension, endian, |_| ())?)),
+            NumberOfCasesRecord::SUBTYPE => Ok(Record::NumberOfCases(NumberOfCasesRecord::parse(&extension, endian, |_| ())?)),
+            x if x == TextExtensionSubtype::VariableSets as u32 => Ok(Record::VariableSets(UnencodedString(extension.data))),
+            x if x == TextExtensionSubtype::ProductInfo as u32 => Ok(Record::ProductInfo(UnencodedString(extension.data))),
+            x if x == TextExtensionSubtype::LongNames as u32 => Ok(Record::LongNames(UnencodedString(extension.data))),
+            x if x == TextExtensionSubtype::LongStrings as u32 => Ok(Record::LongStrings(UnencodedString(extension.data))),
+            x if x == TextExtensionSubtype::FileAttributes as u32 => Ok(Record::FileAttributes(UnencodedString(extension.data))),
+            x if x == TextExtensionSubtype::VariableAttributes as u32 => Ok(Record::VariableAttributes(UnencodedString(extension.data))),
+            _ => Ok(Record::Extension(extension))
+        }
     }
 }