work
[pspp] / rust / src / cooked.rs
index 3fac1840638f1eff945ac722d98c0a40d83d7241..ef4b79ccf46cc5ab5e9c219e67dbb4f259c6ffdd 100644 (file)
@@ -1,10 +1,11 @@
 use std::{borrow::Cow, cmp::Ordering, collections::HashMap, iter::repeat};
 
 use crate::{
+    encoding::{get_encoding, Error as EncodingError},
     endian::Endian,
     format::{Error as FormatError, Spec, UncheckedSpec},
     identifier::{Error as IdError, Identifier},
-    raw::{self, MissingValues, UnencodedStr, VarType}, encoding::get_encoding,
+    raw::{self, MissingValues, UnencodedStr, VarType},
 };
 use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
 use encoding_rs::{DecoderResult, Encoding};
@@ -16,6 +17,14 @@ pub use crate::raw::{CategoryLabels, Compression};
 
 #[derive(ThisError, Debug)]
 pub enum Error {
+    // XXX this is really an internal error and maybe we should change the
+    // interfaces to make it impossible
+    #[error("Missing header record")]
+    MissingHeaderRecord,
+
+    #[error("{0}")]
+    EncodingError(EncodingError),
+
     #[error("Variable record at offset {offset:#x} specifies width {width} not in valid range [-1,255).")]
     InvalidVariableWidth { offset: u64, width: i32 },
 
@@ -187,7 +196,16 @@ pub struct Decoder {
     n_generated_names: usize,
 }
 
-pub fn decode<T>(headers: Vec<raw::Record>, warn: &impl Fn(Error)) -> Vec<Record> {
+pub fn decode<T>(headers: Vec<raw::Record>, warn: &impl Fn(Error)) -> Result<Vec<Record>, Error> {
+    let Some(header_record) = headers.iter().find_map(|rec| {
+        if let raw::Record::Header(header) = rec {
+            Some(header)
+        } else {
+            None
+        }
+    }) else {
+        return Err(Error::MissingHeaderRecord);
+    };
     let encoding = headers.iter().find_map(|rec| {
         if let raw::Record::Encoding(ref e) = rec {
             Some(e.0.as_str())
@@ -202,12 +220,27 @@ pub fn decode<T>(headers: Vec<raw::Record>, warn: &impl Fn(Error)) -> Vec<Record
             None
         }
     });
-    let encoding = get_encoding(encoding, character_code)
+    let encoding = match get_encoding(encoding, character_code) {
+        Ok(encoding) => encoding,
+        Err(err @ EncodingError::Ebcdic) => return Err(Error::EncodingError(err)),
+        Err(err) => {
+            warn(Error::EncodingError(err));
+            // Warn that we're using the default encoding.
+            
+        }
+    };
 
     let decoder = Decoder {
+        compression: header_record.compression,
+        endian: header_record.endian,
+        encoding,
+        variables: HashMap::new(),
+        var_names: HashMap::new(),
+        n_dict_indexes: 0,
+        n_generated_names: 0,
     };
 
-    Vec::new()
+    unreachable!()
 }
 
 impl Decoder {