progress! and clippy!
[pspp] / rust / src / raw.rs
index a38422950ad9fc2d1c8d971756a9c33829b89cb1..1a760a0517dc26168b6084be90266fb7227f8007 100644 (file)
@@ -1,5 +1,6 @@
 use crate::{
     dictionary::VarWidth,
+    encoding::{default_encoding, get_encoding, Error as EncodingError},
     endian::{Endian, Parse, ToBytes},
     identifier::{Error as IdError, Identifier},
 };
@@ -105,6 +106,9 @@ pub enum Error {
         expected_n_blocks: u64,
         ztrailer_len: u64,
     },
+
+    #[error("{0}")]
+    EncodingError(EncodingError),
 }
 
 #[derive(ThisError, Debug)]
@@ -195,6 +199,9 @@ pub enum Warning {
     #[error("Invalid variable name in long string value label record.  {0}")]
     InvalidLongStringValueLabelName(IdError),
 
+    #[error("{0}")]
+    EncodingError(EncodingError),
+
     #[error("Details TBD")]
     TBD,
 }
@@ -227,6 +234,7 @@ pub enum Record {
     Cases(Rc<RefCell<Cases>>),
 }
 
+#[derive(Clone, Debug)]
 pub enum DecodedRecord<'a> {
     Header(HeaderRecord<Cow<'a, str>>),
     Variable(VariableRecord<Cow<'a, str>, String>),
@@ -257,7 +265,7 @@ impl Record {
         reader: &mut R,
         endian: Endian,
         var_types: &[VarType],
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error>
     where
         R: Read + Seek,
@@ -298,7 +306,7 @@ impl Record {
             Record::NumberOfCases(record) => DecodedRecord::NumberOfCases(record.clone()),
             Record::Text(record) => record.decode(decoder),
             Record::OtherExtension(record) => DecodedRecord::OtherExtension(record.clone()),
-            Record::EndOfHeaders(record) => DecodedRecord::EndOfHeaders(record.clone()),
+            Record::EndOfHeaders(record) => DecodedRecord::EndOfHeaders(*record),
             Record::ZHeader(record) => DecodedRecord::ZHeader(record.clone()),
             Record::ZTrailer(record) => DecodedRecord::ZTrailer(record.clone()),
             Record::Cases(_) => todo!(),
@@ -306,6 +314,32 @@ impl Record {
     }
 }
 
+pub fn encoding_from_headers(
+    headers: &Vec<Record>,
+    warn: &impl Fn(Warning),
+) -> Result<&'static Encoding, Error> {
+    let mut encoding_record = None;
+    let mut integer_info_record = None;
+    for record in headers {
+        match record {
+            Record::Encoding(record) => encoding_record = Some(record),
+            Record::IntegerInfo(record) => integer_info_record = Some(record),
+            _ => (),
+        }
+    }
+    let encoding = encoding_record.map(|record| record.0.as_str());
+    let character_code = integer_info_record.map(|record| record.character_code);
+    match get_encoding(encoding, character_code) {
+        Ok(encoding) => Ok(encoding),
+        Err(err @ EncodingError::Ebcdic) => Err(Error::EncodingError(err)),
+        Err(err) => {
+            warn(Warning::EncodingError(err));
+            // Warn that we're using the default encoding.
+            Ok(default_encoding())
+        }
+    }
+}
+
 // If `s` is valid UTF-8, returns it decoded as UTF-8, otherwise returns it
 // decoded as Latin-1 (actually bytes interpreted as Unicode code points).
 fn default_decode(s: &[u8]) -> Cow<str> {
@@ -489,6 +523,15 @@ pub struct Decoder {
 }
 
 impl Decoder {
+    pub fn new<F>(encoding: &'static Encoding, warn: F) -> Self
+    where
+        F: Fn(Warning) + 'static,
+    {
+        Self {
+            encoding,
+            warn: Box::new(warn),
+        }
+    }
     fn warn(&self, warning: Warning) {
         (self.warn)(warning)
     }
@@ -881,15 +924,7 @@ where
             &self.header,
         )
     }
-}
-
-impl<R> Iterator for Reader<R>
-where
-    R: Read + Seek + 'static,
-{
-    type Item = Result<Record, Error>;
-
-    fn next(&mut self) -> Option<Self::Item> {
+    fn _next(&mut self) -> Option<<Self as Iterator>::Item> {
         match self.state {
             ReaderState::Start => {
                 self.state = ReaderState::Headers;
@@ -960,6 +995,21 @@ where
     }
 }
 
+impl<R> Iterator for Reader<R>
+where
+    R: Read + Seek + 'static,
+{
+    type Item = Result<Record, Error>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let retval = self._next();
+        if matches!(retval, Some(Err(_))) {
+            self.state = ReaderState::End;
+        }
+        retval
+    }
+}
+
 trait ReadSeek: Read + Seek {}
 impl<T> ReadSeek for T where T: Read + Seek {}
 
@@ -1176,7 +1226,7 @@ impl MissingValues<RawStr<8>> {
         };
         Ok(Self { values, range })
     }
-    fn decode<'a>(&'a self, decoder: &Decoder) -> MissingValues<String> {
+    fn decode(&self, decoder: &Decoder) -> MissingValues<String> {
         MissingValues {
             values: self
                 .values
@@ -1291,7 +1341,7 @@ impl VariableRecord<RawString, RawStr<8>> {
         }))
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+    pub fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         DecodedRecord::Variable(VariableRecord {
             offsets: self.offsets.clone(),
             width: self.width,
@@ -1440,7 +1490,7 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
         r: &mut R,
         endian: Endian,
         var_types: &[VarType],
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error> {
         let label_offset = r.stream_position()?;
         let n: u32 = endian.parse(read_bytes(r)?);
@@ -1547,7 +1597,7 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
             .labels
             .iter()
             .map(|ValueLabel { value, label }| ValueLabel {
-                value: value.clone(),
+                value: *value,
                 label: decoder.decode(label),
             })
             .collect();
@@ -1605,7 +1655,7 @@ impl DocumentRecord<RawDocumentLine> {
         }
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+    pub fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         DecodedRecord::Document(DocumentRecord {
             offsets: self.offsets.clone(),
             lines: self
@@ -1808,7 +1858,7 @@ impl MultipleResponseSet<RawString, RawString> {
         for short_name in self.short_names.iter() {
             if let Some(short_name) = decoder
                 .decode_identifier(short_name)
-                .map_err(|err| Warning::InvalidMrSetName(err))
+                .map_err(Warning::InvalidMrSetName)
                 .issue_warning(&decoder.warn)
             {
                 short_names.push(short_name);
@@ -1817,10 +1867,10 @@ impl MultipleResponseSet<RawString, RawString> {
         Ok(MultipleResponseSet {
             name: decoder
                 .decode_identifier(&self.name)
-                .map_err(|err| Warning::InvalidMrSetVariableName(err))?,
+                .map_err(Warning::InvalidMrSetVariableName)?,
             label: decoder.decode(&self.label),
             mr_type: self.mr_type.clone(),
-            short_names: short_names,
+            short_names,
         })
     }
 }
@@ -1852,7 +1902,7 @@ impl ExtensionRecord for MultipleResponseRecord<RawString, RawString> {
 }
 
 impl MultipleResponseRecord<RawString, RawString> {
-    fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+    fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         let mut sets = Vec::new();
         for set in self.0.iter() {
             if let Some(set) = set.decode(decoder).issue_warning(&decoder.warn) {
@@ -1952,7 +2002,7 @@ impl VarDisplayRecord {
         ext: &Extension,
         n_vars: usize,
         endian: Endian,
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Record, Warning> {
         if ext.size != 4 {
             return Err(Warning::BadRecordSize {
@@ -2005,7 +2055,7 @@ where
 }
 
 impl LongStringMissingValues<RawString, RawStr<8>> {
-    fn decode<'a>(
+    fn decode(
         &self,
         decoder: &Decoder,
     ) -> Result<LongStringMissingValues<Identifier, String>, IdError> {
@@ -2075,15 +2125,12 @@ impl ExtensionRecord for LongStringMissingValueRecord<RawString, RawStr<8>> {
 }
 
 impl LongStringMissingValueRecord<RawString, RawStr<8>> {
-    pub fn decode<'a>(
-        &self,
-        decoder: &Decoder,
-    ) -> LongStringMissingValueRecord<Identifier, String> {
+    pub fn decode(&self, decoder: &Decoder) -> LongStringMissingValueRecord<Identifier, String> {
         let mut mvs = Vec::with_capacity(self.0.len());
         for mv in self.0.iter() {
             if let Some(mv) = mv
                 .decode(decoder)
-                .map_err(|err| Warning::InvalidLongStringMissingValueVariableName(err))
+                .map_err(Warning::InvalidLongStringMissingValueVariableName)
                 .issue_warning(&decoder.warn)
             {
                 mvs.push(mv);
@@ -2113,7 +2160,7 @@ impl ExtensionRecord for EncodingRecord {
     }
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Clone, Debug)]
 pub struct NumberOfCasesRecord {
     /// Always observed as 1.
     pub one: u64,
@@ -2168,7 +2215,7 @@ impl TextRecord {
             text: extension.data.into(),
         }
     }
-    pub fn decode<'a>(&self, decoder: &Decoder) -> DecodedRecord {
+    pub fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         match self.rec_type {
             TextRecordType::VariableSets => {
                 DecodedRecord::VariableSets(VariableSetRecord::decode(self, decoder))
@@ -2268,7 +2315,7 @@ impl Attribute {
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct AttributeSet(pub HashMap<Identifier, Vec<String>>);
 
 impl AttributeSet {
@@ -2294,13 +2341,7 @@ impl AttributeSet {
     }
 }
 
-impl Default for AttributeSet {
-    fn default() -> Self {
-        Self(HashMap::default())
-    }
-}
-
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct FileAttributeRecord(AttributeSet);
 
 impl FileAttributeRecord {
@@ -2318,12 +2359,6 @@ impl FileAttributeRecord {
     }
 }
 
-impl Default for FileAttributeRecord {
-    fn default() -> Self {
-        Self(AttributeSet::default())
-    }
-}
-
 #[derive(Clone, Debug)]
 pub struct VarAttributeSet {
     pub long_var_name: Identifier,
@@ -2357,12 +2392,12 @@ impl VariableAttributeRecord {
         let mut var_attribute_sets = Vec::new();
         while !input.is_empty() {
             let Some((var_attribute, rest)) =
-                VarAttributeSet::parse(decoder, &input).issue_warning(&decoder.warn)
+                VarAttributeSet::parse(decoder, input).issue_warning(&decoder.warn)
             else {
                 break;
             };
             var_attribute_sets.push(var_attribute);
-            input = rest.into();
+            input = rest;
         }
         VariableAttributeRecord(var_attribute_sets)
     }
@@ -2530,7 +2565,7 @@ impl Extension {
         r: &mut R,
         endian: Endian,
         n_vars: usize,
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error> {
         let subtype = endian.parse(read_bytes(r)?);
         let header_offset = r.stream_position()?;
@@ -2772,7 +2807,7 @@ impl LongStringValueLabels<RawString, RawString> {
         let mut labels = Vec::with_capacity(self.labels.len());
         for (value, label) in self.labels.iter() {
             let value = decoder.decode_exact_length(&value.0);
-            let label = decoder.decode(&label);
+            let label = decoder.decode(label);
             labels.push((value, label));
         }