Works for at least one test file now
[pspp] / rust / src / cooked.rs
index 1749ecc301a0bab8e741ab613c22155276267b66..2e67965e41ff730f7fed916b22cc9d3f31cedc1a 100644 (file)
@@ -1,7 +1,7 @@
 use std::{borrow::Cow, cmp::Ordering, collections::HashMap, iter::repeat};
 
 use crate::{
-    encoding::{get_encoding, Error as EncodingError, default_encoding},
+    encoding::{default_encoding, get_encoding, Error as EncodingError},
     endian::Endian,
     format::{Error as FormatError, Spec, UncheckedSpec},
     identifier::{Error as IdError, Identifier},
@@ -144,6 +144,16 @@ pub enum Error {
     #[error("Invalid variable name in attribute record.  {0}")]
     InvalidAttributeVariableName(IdError),
 
+    // XXX This is risky because `text` might be arbitarily long.
+    #[error("Text string contains invalid bytes for {encoding} encoding: {text}")]
+    MalformedString { encoding: String, text: String },
+
+    #[error("Invalid variable measurement level value {0}")]
+    InvalidMeasurement(u32),
+
+    #[error("Invalid variable display alignment value {0}")]
+    InvalidAlignment(u32),
+
     #[error("Details TBD")]
     TBD,
 }
@@ -199,7 +209,11 @@ pub struct Decoder {
     n_generated_names: usize,
 }
 
-pub fn decode<T>(headers: Vec<raw::Record>, warn: &impl Fn(Error)) -> Result<Vec<Record>, Error> {
+pub fn decode(
+    headers: Vec<raw::Record>,
+    encoding: Option<&'static Encoding>,
+    warn: &impl Fn(Error),
+) -> Result<Vec<Record>, Error> {
     let Some(header_record) = headers.iter().find_map(|rec| {
         if let raw::Record::Header(header) = rec {
             Some(header)
@@ -209,31 +223,36 @@ pub fn decode<T>(headers: Vec<raw::Record>, warn: &impl Fn(Error)) -> Result<Vec
     }) else {
         return Err(Error::MissingHeaderRecord);
     };
-    let encoding = headers.iter().find_map(|rec| {
-        if let raw::Record::Encoding(ref e) = rec {
-            Some(e.0.as_str())
-        } else {
-            None
-        }
-    });
-    let character_code = headers.iter().find_map(|rec| {
-        if let raw::Record::IntegerInfo(ref r) = rec {
-            Some(r.character_code)
-        } else {
-            None
-        }
-    });
-    let encoding = match get_encoding(encoding, character_code) {
-        Ok(encoding) => encoding,
-        Err(err @ EncodingError::Ebcdic) => return Err(Error::EncodingError(err)),
-        Err(err) => {
-            warn(Error::EncodingError(err));
-            // Warn that we're using the default encoding.
-            default_encoding()
+    let encoding = match encoding {
+        Some(encoding) => encoding,
+        None => {
+            let encoding = headers.iter().find_map(|rec| {
+                if let raw::Record::Encoding(ref e) = rec {
+                    Some(e.0.as_str())
+                } else {
+                    None
+                }
+            });
+            let character_code = headers.iter().find_map(|rec| {
+                if let raw::Record::IntegerInfo(ref r) = rec {
+                    Some(r.character_code)
+                } else {
+                    None
+                }
+            });
+            match get_encoding(encoding, character_code) {
+                Ok(encoding) => encoding,
+                Err(err @ EncodingError::Ebcdic) => return Err(Error::EncodingError(err)),
+                Err(err) => {
+                    warn(Error::EncodingError(err));
+                    // Warn that we're using the default encoding.
+                    default_encoding()
+                }
+            }
         }
     };
 
-    let decoder = Decoder {
+    let mut decoder = Decoder {
         compression: header_record.compression,
         endian: header_record.endian,
         encoding,
@@ -243,7 +262,99 @@ pub fn decode<T>(headers: Vec<raw::Record>, warn: &impl Fn(Error)) -> Result<Vec
         n_generated_names: 0,
     };
 
-    unreachable!()
+    let mut output = Vec::with_capacity(headers.len());
+    for header in &headers {
+        match header {
+            raw::Record::Header(ref input) => {
+                if let Some(header) = HeaderRecord::try_decode(&mut decoder, input, warn)? {
+                    output.push(Record::Header(header))
+                }
+            }
+            raw::Record::Variable(ref input) => {
+                if let Some(variable) = VariableRecord::try_decode(&mut decoder, input, warn)? {
+                    output.push(Record::Variable(variable));
+                }
+            }
+            raw::Record::ValueLabel(ref input) => {
+                if let Some(value_label) = ValueLabelRecord::try_decode(&mut decoder, input, warn)?
+                {
+                    output.push(Record::ValueLabel(value_label));
+                }
+            }
+            raw::Record::Document(ref input) => {
+                if let Some(document) = DocumentRecord::try_decode(&mut decoder, input, warn)? {
+                    output.push(Record::Document(document))
+                }
+            }
+            raw::Record::IntegerInfo(ref input) => output.push(Record::IntegerInfo(input.clone())),
+            raw::Record::FloatInfo(ref input) => output.push(Record::FloatInfo(input.clone())),
+            raw::Record::VariableSets(ref input) => {
+                let s = decoder.decode_string_cow(&input.text.0, warn);
+                output.push(Record::VariableSets(VariableSetRecord::parse(&s, warn)?));
+            }
+            raw::Record::VarDisplay(ref input) => {
+                if let Some(vdr) = VarDisplayRecord::try_decode(&mut decoder, input, warn)? {
+                    output.push(Record::VarDisplay(vdr))
+                }
+            }
+            raw::Record::MultipleResponse(ref input) => {
+                if let Some(mrr) = MultipleResponseRecord::try_decode(&mut decoder, input, warn)? {
+                    output.push(Record::MultipleResponse(mrr))
+                }
+            }
+            raw::Record::LongStringValueLabels(ref input) => {
+                if let Some(mrr) =
+                    LongStringValueLabelRecord::try_decode(&mut decoder, input, warn)?
+                {
+                    output.push(Record::LongStringValueLabels(mrr))
+                }
+            }
+            raw::Record::Encoding(ref input) => output.push(Record::Encoding(input.clone())),
+            raw::Record::NumberOfCases(ref input) => {
+                output.push(Record::NumberOfCases(input.clone()))
+            }
+            raw::Record::ProductInfo(ref input) => {
+                let s = decoder.decode_string_cow(&input.text.0, warn);
+                output.push(Record::ProductInfo(ProductInfoRecord::parse(&s, warn)?));
+            }
+            raw::Record::LongNames(ref input) => {
+                let s = decoder.decode_string_cow(&input.text.0, warn);
+                output.push(Record::LongNames(LongNameRecord::parse(
+                    &mut decoder,
+                    &s,
+                    warn,
+                )?));
+            }
+            raw::Record::VeryLongStrings(ref input) => {
+                let s = decoder.decode_string_cow(&input.text.0, warn);
+                output.push(Record::VeryLongStrings(VeryLongStringRecord::parse(
+                    &mut decoder,
+                    &s,
+                    warn,
+                )?));
+            }
+            raw::Record::FileAttributes(ref input) => {
+                let s = decoder.decode_string_cow(&input.text.0, warn);
+                output.push(Record::FileAttributes(FileAttributeRecord::parse(
+                    &decoder, &s, warn,
+                )?));
+            }
+            raw::Record::VariableAttributes(ref input) => {
+                let s = decoder.decode_string_cow(&input.text.0, warn);
+                output.push(Record::VariableAttributes(VariableAttributeRecord::parse(
+                    &decoder, &s, warn,
+                )?));
+            }
+            raw::Record::OtherExtension(ref input) => {
+                output.push(Record::OtherExtension(input.clone()))
+            }
+            raw::Record::EndOfHeaders(_) => (),
+            raw::Record::ZHeader(_) => (),
+            raw::Record::ZTrailer(_) => (),
+            raw::Record::Case(_) => (),
+        };
+    }
+    Ok(output)
 }
 
 impl Decoder {
@@ -261,7 +372,10 @@ impl Decoder {
     fn decode_string_cow<'a>(&self, input: &'a [u8], warn: &impl Fn(Error)) -> Cow<'a, str> {
         let (output, malformed) = self.encoding.decode_without_bom_handling(input);
         if malformed {
-            warn(Error::TBD);
+            warn(Error::MalformedString {
+                encoding: self.encoding.name().into(),
+                text: output.clone().into(),
+            });
         }
         output
     }
@@ -277,14 +391,14 @@ impl Decoder {
         Identifier::new(&s, self.encoding)
     }
     fn get_var_by_index(&self, dict_index: usize) -> Result<&Variable, Error> {
-        let max_index = self.n_dict_indexes - 1;
-        if dict_index == 0 || dict_index as usize > max_index {
+        let max_index = self.n_dict_indexes;
+        if dict_index == 0 || dict_index > max_index {
             return Err(Error::InvalidDictIndex {
                 dict_index,
                 max_index,
             });
         }
-        let Some(variable) = self.variables.get(&dict_index) else {
+        let Some(variable) = self.variables.get(&(dict_index - 1)) else {
             return Err(Error::DictIndexIsContinuation(dict_index));
         };
         Ok(variable)
@@ -328,10 +442,10 @@ impl Decoder {
 pub trait TryDecode: Sized {
     type Input;
     fn try_decode(
-        decoder: &Decoder,
+        decoder: &mut Decoder,
         input: &Self::Input,
         warn: impl Fn(Error),
-    ) -> Result<Self, Error>;
+    ) -> Result<Option<Self>, Error>;
 }
 
 pub trait Decode<Input>: Sized {
@@ -353,23 +467,29 @@ pub struct HeaderRecord {
     pub file_label: String,
 }
 
+fn trim_end_spaces(mut s: String) -> String {
+    s.truncate(s.trim_end_matches(' ').len());
+    s
+}
+
 impl TryDecode for HeaderRecord {
     type Input = crate::raw::HeaderRecord;
 
     fn try_decode(
-        decoder: &Decoder,
+        decoder: &mut Decoder,
         input: &Self::Input,
         warn: impl Fn(Error),
-    ) -> Result<Self, Error> {
-        let eye_catcher = decoder.decode_string(&input.eye_catcher.0, &warn);
-        let file_label = decoder.decode_string(&input.file_label.0, &warn);
+    ) -> Result<Option<Self>, Error> {
+        let eye_catcher = trim_end_spaces(decoder.decode_string(&input.eye_catcher.0, &warn));
+        let file_label = trim_end_spaces(decoder.decode_string(&input.file_label.0, &warn));
         let creation_date = decoder.decode_string_cow(&input.creation_date.0, &warn);
-        let creation_date = NaiveDate::parse_from_str(&creation_date, "%v").unwrap_or_else(|_| {
-            warn(Error::InvalidCreationDate {
-                creation_date: creation_date.into(),
+        let creation_date =
+            NaiveDate::parse_from_str(&creation_date, "%e %b %Y").unwrap_or_else(|_| {
+                warn(Error::InvalidCreationDate {
+                    creation_date: creation_date.into(),
+                });
+                Default::default()
             });
-            Default::default()
-        });
         let creation_time = decoder.decode_string_cow(&input.creation_time.0, &warn);
         let creation_time =
             NaiveTime::parse_from_str(&creation_time, "%H:%M:%S").unwrap_or_else(|_| {
@@ -378,13 +498,13 @@ impl TryDecode for HeaderRecord {
                 });
                 Default::default()
             });
-        Ok(HeaderRecord {
+        Ok(Some(HeaderRecord {
             eye_catcher,
             weight_index: input.weight_index.map(|n| n as usize),
             n_cases: input.n_cases.map(|n| n as u64),
             creation: NaiveDateTime::new(creation_date, creation_time),
             file_label,
-        })
+        }))
     }
 }
 
@@ -473,8 +593,10 @@ fn decode_format(raw: raw::Spec, width: VarWidth, warn: impl Fn(Spec, FormatErro
         })
 }
 
-impl VariableRecord {
-    pub fn decode(
+impl TryDecode for VariableRecord {
+    type Input = raw::VariableRecord;
+
+    fn try_decode(
         decoder: &mut Decoder,
         input: &crate::raw::VariableRecord,
         warn: impl Fn(Error),
@@ -490,7 +612,8 @@ impl VariableRecord {
                 })
             }
         };
-        let name = match decoder.decode_identifier(&input.name.0, &warn) {
+        let name = trim_end_spaces(decoder.decode_string(&input.name.0, &warn));
+        let name = match Identifier::new(&name, decoder.encoding) {
             Ok(name) => {
                 if !decoder.var_names.contains_key(&name) {
                     name
@@ -564,17 +687,17 @@ impl TryDecode for DocumentRecord {
     type Input = crate::raw::DocumentRecord;
 
     fn try_decode(
-        decoder: &Decoder,
+        decoder: &mut Decoder,
         input: &Self::Input,
         warn: impl Fn(Error),
-    ) -> Result<Self, Error> {
-        Ok(DocumentRecord(
+    ) -> Result<Option<Self>, Error> {
+        Ok(Some(DocumentRecord(
             input
                 .lines
                 .iter()
-                .map(|s| decoder.decode_string(&s.0, &warn))
+                .map(|s| trim_end_spaces(decoder.decode_string(&s.0, &warn)))
                 .collect(),
-        ))
+        )))
     }
 }
 
@@ -646,14 +769,14 @@ pub struct ValueLabelRecord {
     pub variables: Vec<Identifier>,
 }
 
-impl ValueLabelRecord {
-    pub fn decode(
+impl TryDecode for ValueLabelRecord {
+    type Input = crate::raw::ValueLabelRecord;
+    fn try_decode(
         decoder: &mut Decoder,
-        raw_value_label: &crate::raw::ValueLabelRecord,
-        dict_indexes: &crate::raw::VarIndexRecord,
+        input: &Self::Input,
         warn: impl Fn(Error),
     ) -> Result<Option<ValueLabelRecord>, Error> {
-        let variables: Vec<&Variable> = dict_indexes
+        let variables: Vec<&Variable> = input
             .dict_indexes
             .iter()
             .filter_map(|&dict_index| {
@@ -690,7 +813,7 @@ impl ValueLabelRecord {
                 return Ok(None);
             }
         }
-        let labels = raw_value_label
+        let labels = input
             .labels
             .iter()
             .map(|(value, label)| {
@@ -960,6 +1083,18 @@ pub enum Measure {
     Scale,
 }
 
+impl Measure {
+    fn try_decode(source: u32) -> Result<Option<Measure>, Error> {
+        match source {
+            0 => Ok(None),
+            1 => Ok(Some(Measure::Nominal)),
+            2 => Ok(Some(Measure::Ordinal)),
+            3 => Ok(Some(Measure::Scale)),
+            _ => Err(Error::InvalidMeasurement(source)),
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 pub enum Alignment {
     Left,
@@ -967,16 +1102,67 @@ pub enum Alignment {
     Center,
 }
 
+impl Alignment {
+    fn try_decode(source: u32) -> Result<Option<Alignment>, Error> {
+        match source {
+            0 => Ok(None),
+            1 => Ok(Some(Alignment::Left)),
+            2 => Ok(Some(Alignment::Right)),
+            3 => Ok(Some(Alignment::Center)),
+            _ => Err(Error::InvalidAlignment(source)),
+        }
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct VarDisplay {
     pub measure: Option<Measure>,
-    pub width: u32,
-    pub align: Option<Alignment>,
+    pub width: Option<u32>,
+    pub alignment: Option<Alignment>,
 }
 
 #[derive(Clone, Debug)]
 pub struct VarDisplayRecord(pub Vec<VarDisplay>);
 
+impl TryDecode for VarDisplayRecord {
+    type Input = raw::VarDisplayRecord;
+    fn try_decode(
+        decoder: &mut Decoder,
+        input: &Self::Input,
+        warn: impl Fn(Error),
+    ) -> Result<Option<Self>, Error> {
+        let n_vars = decoder.variables.len();
+        let n_per_var = if input.0.len() == 3 * n_vars {
+            3
+        } else if input.0.len() == 2 * n_vars {
+            2
+        } else {
+            return Err(Error::TBD);
+        };
+
+        let var_displays = input
+            .0
+            .chunks(n_per_var)
+            .map(|chunk| {
+                let (measure, width, alignment) = match n_per_var == 3 {
+                    true => (chunk[0], Some(chunk[1]), chunk[2]),
+                    false => (chunk[0], None, chunk[1]),
+                };
+                let measure = Measure::try_decode(measure).warn_on_error(&warn).flatten();
+                let alignment = Alignment::try_decode(alignment)
+                    .warn_on_error(&warn)
+                    .flatten();
+                VarDisplay {
+                    measure,
+                    width,
+                    alignment,
+                }
+            })
+            .collect();
+        Ok(Some(VarDisplayRecord(var_displays)))
+    }
+}
+
 #[derive(Clone, Debug)]
 pub enum MultipleResponseType {
     MultipleDichotomy {
@@ -1109,10 +1295,10 @@ impl TryDecode for MultipleResponseRecord {
     type Input = raw::MultipleResponseRecord;
 
     fn try_decode(
-        decoder: &Decoder,
+        decoder: &mut Decoder,
         input: &Self::Input,
         warn: impl Fn(Error),
-    ) -> Result<Self, Error> {
+    ) -> Result<Option<Self>, Error> {
         let mut sets = Vec::with_capacity(input.0.len());
         for set in &input.0 {
             match MultipleResponseSet::decode(decoder, set, &warn) {
@@ -1120,7 +1306,7 @@ impl TryDecode for MultipleResponseRecord {
                 Err(error) => warn(error),
             }
         }
-        Ok(MultipleResponseRecord(sets))
+        Ok(Some(MultipleResponseRecord(sets)))
     }
 }
 
@@ -1137,8 +1323,8 @@ impl LongStringValueLabels {
         input: &raw::LongStringValueLabels,
         warn: &impl Fn(Error),
     ) -> Result<Self, Error> {
-        let var_name = decoder
-            .decode_identifier(&input.var_name.0, warn)
+        let var_name = decoder.decode_string(&input.var_name.0, warn);
+        let var_name = Identifier::new(var_name.trim_end(), decoder.encoding)
             .map_err(|e| Error::InvalidLongStringValueLabelName(e))?;
 
         let min_width = 9;
@@ -1175,10 +1361,10 @@ impl TryDecode for LongStringValueLabelRecord {
     type Input = raw::LongStringValueLabelRecord;
 
     fn try_decode(
-        decoder: &Decoder,
+        decoder: &mut Decoder,
         input: &Self::Input,
         warn: impl Fn(Error),
-    ) -> Result<Self, Error> {
+    ) -> Result<Option<Self>, Error> {
         let mut labels = Vec::with_capacity(input.0.len());
         for label in &input.0 {
             match LongStringValueLabels::decode(decoder, label, &warn) {
@@ -1186,7 +1372,7 @@ impl TryDecode for LongStringValueLabelRecord {
                 Err(error) => warn(error),
             }
         }
-        Ok(LongStringValueLabelRecord(labels))
+        Ok(Some(LongStringValueLabelRecord(labels)))
     }
 }