decodedrecord works
[pspp] / rust / src / raw.rs
index 65b0f9474cf8326c98af843c7f4b95e894633749..e75e5a05e8e8bfac8adba560c1d4834c79fc6d85 100644 (file)
@@ -1,5 +1,5 @@
 use crate::{
-    cooked::VarWidth,
+    dictionary::VarWidth,
     endian::{Endian, Parse, ToBytes},
     identifier::{Error as IdError, Identifier},
 };
@@ -11,7 +11,7 @@ use std::{
     borrow::Cow,
     cell::RefCell,
     cmp::Ordering,
-    collections::{VecDeque, HashMap},
+    collections::{HashMap, VecDeque},
     fmt::{Debug, Display, Formatter, Result as FmtResult},
     io::{Error as IoError, Read, Seek, SeekFrom},
     iter::repeat,
@@ -189,6 +189,9 @@ pub enum Error {
     #[error("Invalid variable name in long string missing values record.  {0}")]
     InvalidLongStringMissingValueVariableName(IdError),
 
+    #[error("Invalid variable name in long string value label record.  {0}")]
+    InvalidLongStringValueLabelName(IdError),
+
     #[error("Details TBD")]
     TBD,
 }
@@ -201,24 +204,43 @@ pub enum Record {
     Document(DocumentRecord<RawDocumentLine>),
     IntegerInfo(IntegerInfoRecord),
     FloatInfo(FloatInfoRecord),
-    VariableSets(VariableSetRecord),
     VarDisplay(VarDisplayRecord),
     MultipleResponse(MultipleResponseRecord<RawString, RawString>),
-    LongStringValueLabels(LongStringValueLabelRecord<RawString>),
+    LongStringValueLabels(LongStringValueLabelRecord<RawString, RawString>),
     LongStringMissingValues(LongStringMissingValueRecord<RawString, RawStr<8>>),
     Encoding(EncodingRecord),
     NumberOfCases(NumberOfCasesRecord),
+    Text(TextRecord),
+    OtherExtension(Extension),
+    EndOfHeaders(u32),
+    ZHeader(ZHeader),
+    ZTrailer(ZTrailer),
+    Cases(Rc<RefCell<Cases>>),
+}
+
+pub enum DecodedRecord<'a> {
+    Header(HeaderRecord<Cow<'a, str>>),
+    Variable(VariableRecord<Cow<'a, str>, String>),
+    ValueLabel(ValueLabelRecord<RawStr<8>, Cow<'a, str>>),
+    Document(DocumentRecord<Cow<'a, str>>),
+    IntegerInfo(IntegerInfoRecord),
+    FloatInfo(FloatInfoRecord),
+    VarDisplay(VarDisplayRecord),
+    MultipleResponse(MultipleResponseRecord<Identifier, Cow<'a, str>>),
+    LongStringValueLabels(LongStringValueLabelRecord<Identifier, Cow<'a, str>>),
+    LongStringMissingValues(LongStringMissingValueRecord<Identifier, String>),
+    Encoding(EncodingRecord),
+    NumberOfCases(NumberOfCasesRecord),
+    VariableSets(VariableSetRecord),
     ProductInfo(ProductInfoRecord),
     LongNames(LongNamesRecord),
     VeryLongStrings(VeryLongStringsRecord),
     FileAttributes(FileAttributeRecord),
     VariableAttributes(VariableAttributeRecord),
-    Text(TextRecord),
     OtherExtension(Extension),
     EndOfHeaders(u32),
     ZHeader(ZHeader),
     ZTrailer(ZTrailer),
-    Cases(Rc<RefCell<Cases>>),
 }
 
 impl Record {
@@ -247,7 +269,32 @@ impl Record {
         }
     }
 
-    
+    /// Converts this raw record into the corresponding [`DecodedRecord`].
+    ///
+    /// Records that have their own `decode` method are decoded with
+    /// `decoder`; the remaining records are copied unchanged into the
+    /// matching `DecodedRecord` variant.  A [`Record::Text`] record is handed
+    /// to `TextRecord::decode`, which picks the resulting variant itself.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the error returned when decoding a long string value label
+    /// record or a text record.
+    ///
+    /// # Panics
+    ///
+    /// Decoding [`Record::Cases`] is not implemented yet and hits `todo!()`.
+    fn decode<'a>(&'a self, decoder: &Decoder) -> Result<DecodedRecord<'a>, Error> {
+        let decoded = match self {
+            Self::Header(raw) => raw.decode(decoder),
+            Self::Variable(raw) => raw.decode(decoder),
+            Self::ValueLabel(raw) => DecodedRecord::ValueLabel(raw.decode(decoder)),
+            Self::Document(raw) => raw.decode(decoder),
+            Self::IntegerInfo(raw) => DecodedRecord::IntegerInfo(raw.clone()),
+            Self::FloatInfo(raw) => DecodedRecord::FloatInfo(raw.clone()),
+            Self::VarDisplay(raw) => DecodedRecord::VarDisplay(raw.clone()),
+            Self::MultipleResponse(raw) => raw.decode(decoder),
+            Self::LongStringValueLabels(raw) => {
+                DecodedRecord::LongStringValueLabels(raw.decode(decoder)?)
+            }
+            Self::LongStringMissingValues(raw) => {
+                DecodedRecord::LongStringMissingValues(raw.decode(decoder))
+            }
+            Self::Encoding(raw) => DecodedRecord::Encoding(raw.clone()),
+            Self::NumberOfCases(raw) => DecodedRecord::NumberOfCases(raw.clone()),
+            Self::Text(raw) => raw.decode(decoder)?,
+            Self::OtherExtension(raw) => DecodedRecord::OtherExtension(raw.clone()),
+            // The payload is a plain `u32`, so it is copied rather than cloned.
+            Self::EndOfHeaders(value) => DecodedRecord::EndOfHeaders(*value),
+            Self::ZHeader(raw) => DecodedRecord::ZHeader(raw.clone()),
+            Self::ZTrailer(raw) => DecodedRecord::ZTrailer(raw.clone()),
+            Self::Cases(_) => todo!(),
+        };
+        Ok(decoded)
+    }
 }
 
 // If `s` is valid UTF-8, returns it decoded as UTF-8, otherwise returns it
@@ -404,12 +451,12 @@ impl HeaderRecord<RawString> {
         })
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> HeaderRecord<Cow<'a, str>> {
+    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord<'a> {
         let eye_catcher = decoder.decode(&self.eye_catcher);
         let file_label = decoder.decode(&self.file_label);
         let creation_date = decoder.decode(&self.creation_date);
         let creation_time = decoder.decode(&self.creation_time);
-        HeaderRecord {
+        DecodedRecord::Header(HeaderRecord {
             eye_catcher,
             weight_index: self.weight_index,
             n_cases: self.n_cases,
@@ -423,7 +470,7 @@ impl HeaderRecord<RawString> {
             creation_date,
             creation_time,
             endian: self.endian,
-        }
+        })
     }
 }
 
@@ -456,7 +503,7 @@ impl Decoder {
     /// same length in bytes.
     ///
     /// XXX warn about errors?
-    fn decode_exact_length<'a>(&self, input: &'a [u8]) -> Cow<'a, str> {
+    pub fn decode_exact_length<'a>(&self, input: &'a [u8]) -> Cow<'a, str> {
         if let (s, false) = self.encoding.decode_without_bom_handling(input) {
             // This is the common case.  Usually there will be no errors.
             s
@@ -1235,8 +1282,8 @@ impl VariableRecord<RawString, RawStr<8>> {
         }))
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> VariableRecord<Cow<'a, str>, String> {
-        VariableRecord {
+    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+        DecodedRecord::Variable(VariableRecord {
             offsets: self.offsets.clone(),
             width: self.width,
             name: decoder.decode(&self.name),
@@ -1244,7 +1291,7 @@ impl VariableRecord<RawString, RawStr<8>> {
             write_format: self.write_format,
             missing_values: self.missing_values.decode(decoder),
             label: self.label.as_ref().map(|label| decoder.decode(label)),
-        }
+        })
     }
 }
 
@@ -1485,6 +1532,23 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
             var_type,
         })))
     }
+
+    /// Returns a copy of this record in which every label has been decoded
+    /// with `decoder`.
+    ///
+    /// The values, offsets, dictionary indexes, and variable type are carried
+    /// over unchanged.
+    fn decode<'a>(&'a self, decoder: &Decoder) -> ValueLabelRecord<RawStr<8>, Cow<'a, str>> {
+        ValueLabelRecord {
+            offsets: self.offsets.clone(),
+            labels: self
+                .labels
+                .iter()
+                .map(|raw| ValueLabel {
+                    value: raw.value.clone(),
+                    label: decoder.decode(&raw.label),
+                })
+                .collect(),
+            dict_indexes: self.dict_indexes.clone(),
+            var_type: self.var_type,
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -1532,15 +1596,15 @@ impl DocumentRecord<RawDocumentLine> {
         }
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DocumentRecord<Cow<'a, str>> {
-        DocumentRecord {
+    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+        DecodedRecord::Document(DocumentRecord {
             offsets: self.offsets.clone(),
             lines: self
                 .lines
                 .iter()
                 .map(|s| decoder.decode_slice(&s.0))
                 .collect(),
-        }
+        })
     }
 }
 
@@ -1779,14 +1843,14 @@ impl ExtensionRecord for MultipleResponseRecord<RawString, RawString> {
 }
 
 impl MultipleResponseRecord<RawString, RawString> {
-    fn decode<'a>(&'a self, decoder: &Decoder) -> MultipleResponseRecord<Identifier, Cow<'a, str>> {
+    fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
         let mut sets = Vec::new();
         for set in self.0.iter() {
             if let Some(set) = set.decode(decoder).warn_on_error(&decoder.warn) {
                 sets.push(set);
             }
         }
-        MultipleResponseRecord(sets)
+        DecodedRecord::MultipleResponse(MultipleResponseRecord(sets))
     }
 }
 
@@ -2095,26 +2159,26 @@ impl TextRecord {
             text: extension.data.into(),
         }
     }
-    pub fn decode<'a>(&self, decoder: &Decoder) -> Result<Option<Record>, Error> {
+    pub fn decode<'a>(&self, decoder: &Decoder) -> Result<DecodedRecord, Error> {
         match self.rec_type {
-            TextRecordType::VariableSets => Ok(Some(Record::VariableSets(
+            TextRecordType::VariableSets => Ok(DecodedRecord::VariableSets(
                 VariableSetRecord::decode(self, decoder),
-            ))),
-            TextRecordType::ProductInfo => Ok(Some(Record::ProductInfo(
+            )),
+            TextRecordType::ProductInfo => Ok(DecodedRecord::ProductInfo(
                 ProductInfoRecord::decode(self, decoder),
-            ))),
-            TextRecordType::LongNames => Ok(Some(Record::LongNames(LongNamesRecord::decode(
+            )),
+            TextRecordType::LongNames => Ok(DecodedRecord::LongNames(LongNamesRecord::decode(
                 self, decoder,
-            )))),
-            TextRecordType::VeryLongStrings => Ok(Some(Record::VeryLongStrings(
-                VeryLongStringsRecord::decode(self, decoder),
             ))),
-            TextRecordType::FileAttributes => {
-                Ok(FileAttributeRecord::decode(self, decoder).map(|fa| Record::FileAttributes(fa)))
-            }
-            TextRecordType::VariableAttributes => Ok(Some(Record::VariableAttributes(
+            TextRecordType::VeryLongStrings => Ok(DecodedRecord::VeryLongStrings(
+                VeryLongStringsRecord::decode(self, decoder),
+            )),
+            TextRecordType::FileAttributes => Ok(DecodedRecord::FileAttributes(
+                FileAttributeRecord::decode(self, decoder),
+            )),
+            TextRecordType::VariableAttributes => Ok(DecodedRecord::VariableAttributes(
                 VariableAttributeRecord::decode(self, decoder),
-            ))),
+            )),
         }
     }
 }
@@ -2221,24 +2285,36 @@ impl AttributeSet {
     }
 }
 
+impl Default for AttributeSet {
+    /// Returns an attribute set whose underlying map is empty.
+    fn default() -> Self {
+        AttributeSet(Default::default())
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct FileAttributeRecord(AttributeSet);
 
 impl FileAttributeRecord {
-    fn decode(source: &TextRecord, decoder: &Decoder) -> Option<Self> {
+    fn decode(source: &TextRecord, decoder: &Decoder) -> Self {
         let input = decoder.decode(&source.text);
         match AttributeSet::parse(decoder, &input, None).warn_on_error(&decoder.warn) {
             Some((set, rest)) => {
                 if !rest.is_empty() {
                     decoder.warn(Error::TBD);
                 }
-                Some(FileAttributeRecord(set))
+                FileAttributeRecord(set)
             }
-            None => None,
+            None => FileAttributeRecord::default(),
         }
     }
 }
 
+impl Default for FileAttributeRecord {
+    /// Returns a record that wraps the default [`AttributeSet`].
+    fn default() -> Self {
+        FileAttributeRecord(Default::default())
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct VarAttributeSet {
     pub long_var_name: Identifier,
@@ -2660,23 +2736,48 @@ fn read_string<R: Read>(r: &mut R, endian: Endian) -> Result<RawString, IoError>
 }
 
 #[derive(Clone, Debug)]
-pub struct LongStringValueLabels<S>
+pub struct LongStringValueLabels<N, S>
 where
     S: Debug,
 {
-    pub var_name: S,
+    pub var_name: N,
     pub width: u32,
 
     /// `(value, label)` pairs, where each value is `width` bytes.
     pub labels: Vec<(S, S)>,
 }
 
+impl LongStringValueLabels<RawString, RawString> {
+    /// Decodes the variable name, values, and labels with `decoder`.
+    ///
+    /// The variable name is decoded, stripped of trailing whitespace, and
+    /// turned into an [`Identifier`].  Each value is decoded with
+    /// `Decoder::decode_exact_length` and each label with `Decoder::decode`;
+    /// `width` is copied as is.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::InvalidLongStringValueLabelName`] if `Identifier::new`
+    /// rejects the decoded variable name.
+    fn decode<'a>(
+        &'a self,
+        decoder: &Decoder,
+    ) -> Result<LongStringValueLabels<Identifier, Cow<'a, str>>, Error> {
+        let name = decoder.decode(&self.var_name);
+        let var_name = Identifier::new(name.trim_end(), decoder.encoding)
+            .map_err(Error::InvalidLongStringValueLabelName)?;
+
+        let labels = self
+            .labels
+            .iter()
+            .map(|(value, label)| {
+                let value = decoder.decode_exact_length(&value.0);
+                let label = decoder.decode(label);
+                (value, label)
+            })
+            .collect();
+
+        Ok(LongStringValueLabels {
+            var_name,
+            width: self.width,
+            labels,
+        })
+    }
+}
+
 #[derive(Clone, Debug)]
-pub struct LongStringValueLabelRecord<S>(pub Vec<LongStringValueLabels<S>>)
+pub struct LongStringValueLabelRecord<N, S>(pub Vec<LongStringValueLabels<N, S>>)
 where
+    N: Debug,
     S: Debug;
 
-impl ExtensionRecord for LongStringValueLabelRecord<RawString> {
+impl ExtensionRecord for LongStringValueLabelRecord<RawString, RawString> {
     const SUBTYPE: u32 = 21;
     const SIZE: Option<u32> = Some(1);
     const COUNT: Option<u32> = None;
@@ -2708,3 +2809,19 @@ impl ExtensionRecord for LongStringValueLabelRecord<RawString> {
         )))
     }
 }
+
+impl LongStringValueLabelRecord<RawString, RawString> {
+    /// Decodes every set of long string value labels in this record.
+    ///
+    /// A set that fails to decode is passed to `decoder.warn` and left out of
+    /// the result, so as written this always returns `Ok`.
+    fn decode<'a>(
+        &'a self,
+        decoder: &Decoder,
+    ) -> Result<LongStringValueLabelRecord<Identifier, Cow<'a, str>>, Error> {
+        let decoded = self
+            .0
+            .iter()
+            .filter_map(|set| match set.decode(decoder) {
+                Ok(set) => Some(set),
+                Err(error) => {
+                    decoder.warn(error);
+                    None
+                }
+            })
+            .collect();
+        Ok(LongStringValueLabelRecord(decoded))
+    }
+}