work
[pspp] / rust / src / raw.rs
index 4ab11c3043dcf4205601df0a0c10f767ada7b07e..65b0f9474cf8326c98af843c7f4b95e894633749 100644 (file)
@@ -1,4 +1,8 @@
-use crate::endian::{Endian, Parse, ToBytes};
+use crate::{
+    cooked::VarWidth,
+    endian::{Endian, Parse, ToBytes},
+    identifier::{Error as IdError, Identifier},
+};
 
 use encoding_rs::{mem::decode_latin1, DecoderResult, Encoding};
 use flate2::read::ZlibDecoder;
@@ -7,7 +11,7 @@ use std::{
     borrow::Cow,
     cell::RefCell,
     cmp::Ordering,
-    collections::VecDeque,
+    collections::{VecDeque, HashMap},
     fmt::{Debug, Display, Formatter, Result as FmtResult},
     io::{Error as IoError, Read, Seek, SeekFrom},
     iter::repeat,
@@ -158,6 +162,33 @@ pub enum Error {
     #[error("Invalid variable display alignment value {0}")]
     InvalidAlignment(u32),
 
+    #[error("Invalid attribute name.  {0}")]
+    InvalidAttributeName(IdError),
+
+    #[error("Invalid variable name in attribute record.  {0}")]
+    InvalidAttributeVariableName(IdError),
+
+    #[error("Invalid short name in long variable name record.  {0}")]
+    InvalidShortName(IdError),
+
+    #[error("Invalid name in long variable name record.  {0}")]
+    InvalidLongName(IdError),
+
+    #[error("Invalid variable name in very long string record.  {0}")]
+    InvalidLongStringName(IdError),
+
+    #[error("Invalid variable name in variable set record.  {0}")]
+    InvalidVariableSetName(IdError),
+
+    #[error("Invalid multiple response set name.  {0}")]
+    InvalidMrSetName(IdError),
+
+    #[error("Invalid multiple response set variable name.  {0}")]
+    InvalidMrSetVariableName(IdError),
+
+    #[error("Invalid variable name in long string missing values record.  {0}")]
+    InvalidLongStringMissingValueVariableName(IdError),
+
     #[error("Details TBD")]
     TBD,
 }
@@ -172,7 +203,7 @@ pub enum Record {
     FloatInfo(FloatInfoRecord),
     VariableSets(VariableSetRecord),
     VarDisplay(VarDisplayRecord),
-    MultipleResponse(MultipleResponseRecord<RawString>),
+    MultipleResponse(MultipleResponseRecord<RawString, RawString>),
     LongStringValueLabels(LongStringValueLabelRecord<RawString>),
     LongStringMissingValues(LongStringMissingValueRecord<RawString, RawStr<8>>),
     Encoding(EncodingRecord),
@@ -215,6 +246,8 @@ impl Record {
             }),
         }
     }
+
+    
 }
 
 // If `s` is valid UTF-8, returns it decoded as UTF-8, otherwise returns it
@@ -371,7 +404,7 @@ impl HeaderRecord<RawString> {
         })
     }
 
-    fn decode<'a>(&'a self, decoder: &Decoder) -> HeaderRecord<Cow<'a, str>> {
+    pub fn decode<'a>(&'a self, decoder: &Decoder) -> HeaderRecord<Cow<'a, str>> {
         let eye_catcher = decoder.decode(&self.eye_catcher);
         let file_label = decoder.decode(&self.file_label);
         let creation_date = decoder.decode(&self.creation_date);
@@ -394,9 +427,9 @@ impl HeaderRecord<RawString> {
     }
 }
 
-struct Decoder {
-    encoding: &'static Encoding,
-    warn: Box<dyn Fn(Error)>,
+pub struct Decoder {
+    pub encoding: &'static Encoding,
+    pub warn: Box<dyn Fn(Error)>,
 }
 
 impl Decoder {
@@ -451,6 +484,14 @@ impl Decoder {
             output.into()
         }
     }
+
+    pub fn decode_identifier(&self, input: &RawString) -> Result<Identifier, IdError> {
+        self.new_identifier(&self.decode(input))
+    }
+
+    pub fn new_identifier(&self, name: &str) -> Result<Identifier, IdError> {
+        Identifier::new(name, self.encoding)
+    }
 }
 
 impl<S> Header for HeaderRecord<S>
@@ -517,14 +558,14 @@ pub enum VarType {
 }
 
 impl VarType {
-    fn from_width(width: i32) -> VarType {
+    pub fn from_width(width: VarWidth) -> VarType {
         match width {
-            0 => VarType::Numeric,
-            _ => VarType::String,
+            VarWidth::Numeric => Self::Numeric,
+            VarWidth::String(_) => Self::String,
         }
     }
 
-    fn opposite(self) -> VarType {
+    pub fn opposite(self) -> VarType {
         match self {
             Self::Numeric => Self::String,
             Self::String => Self::Numeric,
@@ -813,7 +854,11 @@ where
                 };
                 match record {
                     Record::Variable(VariableRecord { width, .. }) => {
-                        self.var_types.push(VarType::from_width(width));
+                        self.var_types.push(if width == 0 {
+                            VarType::Numeric
+                        } else {
+                            VarType::String
+                        });
                     }
                     Record::EndOfHeaders(_) => {
                         self.state = if let Some(Compression::ZLib) = self.header.compression {
@@ -981,7 +1026,7 @@ fn format_name(type_: u32) -> Cow<'static, str> {
 }
 
 #[derive(Clone)]
-pub struct MissingValues<S>
+pub struct MissingValues<S = String>
 where
     S: Debug,
 {
@@ -1028,6 +1073,18 @@ where
     }
 }
 
+impl<S> Default for MissingValues<S>
+where
+    S: Debug,
+{
+    fn default() -> Self {
+        Self {
+            values: Vec::new(),
+            range: None,
+        }
+    }
+}
+
 impl MissingValues<RawStr<8>> {
     fn read<R: Read + Seek>(
         r: &mut R,
@@ -1044,7 +1101,11 @@ impl MissingValues<RawStr<8>> {
             (_, _) => return Err(Error::BadStringMissingValueCode { offset, code }),
         };
 
-        let var_type = VarType::from_width(width);
+        let var_type = if width == 0 {
+            VarType::Numeric
+        } else {
+            VarType::String
+        };
 
         let mut values = Vec::new();
         for _ in 0..n_values {
@@ -1174,7 +1235,7 @@ impl VariableRecord<RawString, RawStr<8>> {
         }))
     }
 
-    fn decode<'a>(&'a self, decoder: &Decoder) -> VariableRecord<Cow<'a, str>, String> {
+    pub fn decode<'a>(&'a self, decoder: &Decoder) -> VariableRecord<Cow<'a, str>, String> {
         VariableRecord {
             offsets: self.offsets.clone(),
             width: self.width,
@@ -1471,7 +1532,7 @@ impl DocumentRecord<RawDocumentLine> {
         }
     }
 
-    fn decode<'a>(&'a self, decoder: &Decoder) -> DocumentRecord<Cow<'a, str>> {
+    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DocumentRecord<Cow<'a, str>> {
         DocumentRecord {
             offsets: self.offsets.clone(),
             lines: self
@@ -1614,17 +1675,18 @@ impl MultipleResponseType {
 }
 
 #[derive(Clone, Debug)]
-pub struct MultipleResponseSet<S>
+pub struct MultipleResponseSet<I, S>
 where
+    I: Debug,
     S: Debug,
 {
-    pub name: S,
+    pub name: I,
     pub label: S,
     pub mr_type: MultipleResponseType,
-    pub short_names: Vec<S>,
+    pub short_names: Vec<I>,
 }
 
-impl MultipleResponseSet<RawString> {
+impl MultipleResponseSet<RawString, RawString> {
     fn parse(input: &[u8]) -> Result<(Self, &[u8]), Error> {
         let Some(equals) = input.iter().position(|&b| b == b'=') else {
             return Err(Error::TBD);
@@ -1665,22 +1727,38 @@ impl MultipleResponseSet<RawString> {
         ))
     }
 
-    fn decode<'a>(&'a self, decoder: &Decoder) -> MultipleResponseSet<Cow<'a, str>> {
-        MultipleResponseSet {
-            name: decoder.decode(&self.name),
+    fn decode<'a>(
+        &'a self,
+        decoder: &Decoder,
+    ) -> Result<MultipleResponseSet<Identifier, Cow<'a, str>>, Error> {
+        let mut short_names = Vec::with_capacity(self.short_names.len());
+        for short_name in self.short_names.iter() {
+            if let Some(short_name) = decoder
+                .decode_identifier(short_name)
+                .map_err(|err| Error::InvalidMrSetName(err))
+                .warn_on_error(&decoder.warn)
+            {
+                short_names.push(short_name);
+            }
+        }
+        Ok(MultipleResponseSet {
+            name: decoder
+                .decode_identifier(&self.name)
+                .map_err(|err| Error::InvalidMrSetVariableName(err))?,
             label: decoder.decode(&self.label),
             mr_type: self.mr_type.clone(),
-            short_names: self.short_names.iter().map(|s| decoder.decode(s)).collect(),
-        }
+            short_names: short_names,
+        })
     }
 }
 
 #[derive(Clone, Debug)]
-pub struct MultipleResponseRecord<S>(pub Vec<MultipleResponseSet<S>>)
+pub struct MultipleResponseRecord<I, S>(pub Vec<MultipleResponseSet<I, S>>)
 where
+    I: Debug,
     S: Debug;
 
-impl ExtensionRecord for MultipleResponseRecord<RawString> {
+impl ExtensionRecord for MultipleResponseRecord<RawString, RawString> {
     const SUBTYPE: u32 = 7;
     const SIZE: Option<u32> = Some(1);
     const COUNT: Option<u32> = None;
@@ -1700,9 +1778,15 @@ impl ExtensionRecord for MultipleResponseRecord<RawString> {
     }
 }
 
-impl MultipleResponseRecord<RawString> {
-    fn decode<'a>(&'a self, decoder: &Decoder) -> MultipleResponseRecord<Cow<'a, str>> {
-        MultipleResponseRecord(self.0.iter().map(|set| set.decode(decoder)).collect())
+impl MultipleResponseRecord<RawString, RawString> {
+    fn decode<'a>(&'a self, decoder: &Decoder) -> MultipleResponseRecord<Identifier, Cow<'a, str>> {
+        let mut sets = Vec::new();
+        for set in self.0.iter() {
+            if let Some(set) = set.decode(decoder).warn_on_error(&decoder.warn) {
+                sets.push(set);
+            }
+        }
+        MultipleResponseRecord(sets)
     }
 }
 
@@ -1734,6 +1818,13 @@ pub enum Measure {
 }
 
 impl Measure {
+    pub fn default_for_type(var_type: VarType) -> Option<Measure> {
+        match var_type {
+            VarType::Numeric => None,
+            VarType::String => Some(Self::Nominal),
+        }
+    }
+
     fn try_decode(source: u32) -> Result<Option<Measure>, Error> {
         match source {
             0 => Ok(None),
@@ -1762,6 +1853,13 @@ impl Alignment {
             _ => Err(Error::InvalidAlignment(source)),
         }
     }
+
+    pub fn default_for_type(var_type: VarType) -> Self {
+        match var_type {
+            VarType::Numeric => Self::Right,
+            VarType::String => Self::Left,
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -1834,11 +1932,14 @@ where
 }
 
 impl LongStringMissingValues<RawString, RawStr<8>> {
-    fn decode<'a>(&self, decoder: &Decoder) -> LongStringMissingValues<String, String> {
-        LongStringMissingValues {
-            var_name: decoder.decode(&self.var_name).to_string(),
+    fn decode<'a>(
+        &self,
+        decoder: &Decoder,
+    ) -> Result<LongStringMissingValues<Identifier, String>, IdError> {
+        Ok(LongStringMissingValues {
+            var_name: decoder.decode_identifier(&self.var_name)?,
             missing_values: self.missing_values.decode(decoder),
-        }
+        })
     }
 }
 
@@ -1901,8 +2002,21 @@ impl ExtensionRecord for LongStringMissingValueRecord<RawString, RawStr<8>> {
 }
 
 impl LongStringMissingValueRecord<RawString, RawStr<8>> {
-    fn decode<'a>(&self, decoder: &Decoder) -> LongStringMissingValueRecord<String, String> {
-        LongStringMissingValueRecord(self.0.iter().map(|mv| mv.decode(decoder)).collect())
+    pub fn decode<'a>(
+        &self,
+        decoder: &Decoder,
+    ) -> LongStringMissingValueRecord<Identifier, String> {
+        let mut mvs = Vec::with_capacity(self.0.len());
+        for mv in self.0.iter() {
+            if let Some(mv) = mv
+                .decode(decoder)
+                .map_err(|err| Error::InvalidLongStringMissingValueVariableName(err))
+                .warn_on_error(&decoder.warn)
+            {
+                mvs.push(mv);
+            }
+        }
+        LongStringMissingValueRecord(mvs)
     }
 }
 
@@ -1981,7 +2095,7 @@ impl TextRecord {
             text: extension.data.into(),
         }
     }
-    fn decode<'a>(&self, decoder: &Decoder) -> Result<Option<Record>, Error> {
+    pub fn decode<'a>(&self, decoder: &Decoder) -> Result<Option<Record>, Error> {
         match self.rec_type {
             TextRecordType::VariableSets => Ok(Some(Record::VariableSets(
                 VariableSetRecord::decode(self, decoder),
@@ -1998,17 +2112,16 @@ impl TextRecord {
             TextRecordType::FileAttributes => {
                 Ok(FileAttributeRecord::decode(self, decoder).map(|fa| Record::FileAttributes(fa)))
             }
-            TextRecordType::VariableAttributes => {
-                Ok(Some(Record::VariableAttributes(
-VariableAttributeRecord::decode(self, decoder))))
-            }
+            TextRecordType::VariableAttributes => Ok(Some(Record::VariableAttributes(
+                VariableAttributeRecord::decode(self, decoder),
+            ))),
         }
     }
 }
 
 #[derive(Clone, Debug)]
 pub struct VeryLongString {
-    pub short_name: String,
+    pub short_name: Identifier,
     pub length: u16,
 }
 
@@ -2017,17 +2130,37 @@ impl VeryLongString {
         let Some((short_name, length)) = input.split_once('=') else {
             return Err(Error::TBD);
         };
+        let short_name = decoder
+            .new_identifier(short_name)
+            .map_err(Error::InvalidLongStringName)?;
         let length = length.parse().map_err(|_| Error::TBD)?;
-        Ok(VeryLongString {
-            short_name: short_name.into(),
-            length,
-        })
+        Ok(VeryLongString { short_name, length })
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct VeryLongStringsRecord(Vec<VeryLongString>);
+
+impl VeryLongStringsRecord {
+    fn decode(source: &TextRecord, decoder: &Decoder) -> Self {
+        let input = decoder.decode(&source.text);
+        let mut very_long_strings = Vec::new();
+        for tuple in input
+            .split('\0')
+            .map(|s| s.trim_end_matches('\t'))
+            .filter(|s| !s.is_empty())
+        {
+            if let Some(vls) = VeryLongString::parse(decoder, tuple).warn_on_error(&decoder.warn) {
+                very_long_strings.push(vls)
+            }
+        }
+        VeryLongStringsRecord(very_long_strings)
     }
 }
 
 #[derive(Clone, Debug)]
 pub struct Attribute {
-    pub name: String,
+    pub name: Identifier,
     pub values: Vec<String>,
 }
 
@@ -2036,6 +2169,9 @@ impl Attribute {
         let Some((name, mut input)) = input.split_once('(') else {
             return Err(Error::TBD);
         };
+        let name = decoder
+            .new_identifier(name)
+            .map_err(Error::InvalidAttributeName)?;
         let mut values = Vec::new();
         loop {
             let Some((value, rest)) = input.split_once('\n') else {
@@ -2051,10 +2187,7 @@ impl Attribute {
                 values.push(value.into());
             }
             if let Some(rest) = rest.strip_prefix(')') {
-                let attribute = Attribute {
-                    name: name.into(),
-                    values,
-                };
+                let attribute = Attribute { name, values };
                 return Ok((attribute, rest));
             };
             input = rest;
@@ -2063,7 +2196,7 @@ impl Attribute {
 }
 
 #[derive(Clone, Debug)]
-pub struct AttributeSet(pub Vec<Attribute>);
+pub struct AttributeSet(pub HashMap<Identifier, Vec<String>>);
 
 impl AttributeSet {
     fn parse<'a>(
@@ -2071,14 +2204,15 @@ impl AttributeSet {
         mut input: &'a str,
         sentinel: Option<char>,
     ) -> Result<(AttributeSet, &'a str), Error> {
-        let mut attributes = Vec::new();
+        let mut attributes = HashMap::new();
         let rest = loop {
             match input.chars().next() {
                 None => break input,
                 c if c == sentinel => break &input[1..],
                 _ => {
                     let (attribute, rest) = Attribute::parse(decoder, input)?;
-                    attributes.push(attribute);
+                    // XXX report duplicate name
+                    attributes.insert(attribute.name, attribute.values);
                     input = rest;
                 }
             }
@@ -2107,7 +2241,7 @@ impl FileAttributeRecord {
 
 #[derive(Clone, Debug)]
 pub struct VarAttributeSet {
-    pub long_var_name: String,
+    pub long_var_name: Identifier,
     pub attributes: AttributeSet,
 }
 
@@ -2116,9 +2250,12 @@ impl VarAttributeSet {
         let Some((long_var_name, rest)) = input.split_once(':') else {
             return Err(Error::TBD);
         };
+        let long_var_name = decoder
+            .new_identifier(long_var_name)
+            .map_err(Error::InvalidAttributeVariableName)?;
         let (attributes, rest) = AttributeSet::parse(decoder, rest, Some('/'))?;
         let var_attribute = VarAttributeSet {
-            long_var_name: long_var_name.into(),
+            long_var_name,
             attributes,
         };
         Ok((var_attribute, rest))
@@ -2147,29 +2284,27 @@ impl VariableAttributeRecord {
 }
 
 #[derive(Clone, Debug)]
-pub struct VeryLongStringsRecord(Vec<VeryLongString>);
-
-impl VeryLongStringsRecord {
-    fn decode(source: &TextRecord, decoder: &Decoder) -> Self {
-        let input = decoder.decode(&source.text);
-        let mut very_long_strings = Vec::new();
-        for tuple in input
-            .split('\0')
-            .map(|s| s.trim_end_matches('\t'))
-            .filter(|s| !s.is_empty())
-        {
-            if let Some(vls) = VeryLongString::parse(decoder, tuple).warn_on_error(&decoder.warn) {
-                very_long_strings.push(vls)
-            }
-        }
-        VeryLongStringsRecord(very_long_strings)
-    }
+pub struct LongName {
+    pub short_name: Identifier,
+    pub long_name: Identifier,
 }
 
-#[derive(Clone, Debug)]
-pub struct LongName {
-    pub short_name: String,
-    pub long_name: String,
+impl LongName {
+    fn parse(input: &str, decoder: &Decoder) -> Result<Self, Error> {
+        let Some((short_name, long_name)) = input.split_once('=') else {
+            return Err(Error::TBD);
+        };
+        let short_name = decoder
+            .new_identifier(short_name)
+            .map_err(Error::InvalidShortName)?;
+        let long_name = decoder
+            .new_identifier(long_name)
+            .map_err(Error::InvalidLongName)?;
+        Ok(LongName {
+            short_name,
+            long_name,
+        })
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -2180,13 +2315,8 @@ impl LongNamesRecord {
         let input = decoder.decode(&source.text);
         let mut names = Vec::new();
         for pair in input.split('\t').filter(|s| !s.is_empty()) {
-            if let Some((short_name, long_name)) = pair.split_once('=') {
-                names.push(LongName {
-                    short_name: short_name.into(),
-                    long_name: long_name.into(),
-                });
-            } else {
-                decoder.warn(Error::TBD)
+            if let Some(long_name) = LongName::parse(pair, decoder).warn_on_error(&decoder.warn) {
+                names.push(long_name);
             }
         }
         LongNamesRecord(names)
@@ -2205,13 +2335,22 @@ impl ProductInfoRecord {
 #[derive(Clone, Debug)]
 pub struct VariableSet {
     pub name: String,
-    pub vars: Vec<String>,
+    pub vars: Vec<Identifier>,
 }
 
 impl VariableSet {
-    fn parse(input: &str) -> Result<Self, Error> {
+    fn parse(input: &str, decoder: &Decoder) -> Result<Self, Error> {
         let (name, input) = input.split_once('=').ok_or(Error::TBD)?;
-        let vars = input.split_ascii_whitespace().map(String::from).collect();
+        let mut vars = Vec::new();
+        for var in input.split_ascii_whitespace() {
+            if let Some(identifier) = decoder
+                .new_identifier(var)
+                .map_err(Error::InvalidVariableSetName)
+                .warn_on_error(&decoder.warn)
+            {
+                vars.push(identifier);
+            }
+        }
         Ok(VariableSet {
             name: name.into(),
             vars,
@@ -2230,7 +2369,7 @@ impl VariableSetRecord {
         let mut sets = Vec::new();
         let input = decoder.decode(&source.text);
         for line in input.lines() {
-            if let Some(set) = VariableSet::parse(line).warn_on_error(&decoder.warn) {
+            if let Some(set) = VariableSet::parse(line, decoder).warn_on_error(&decoder.warn) {
                 sets.push(set)
             }
         }