decode weight variables too
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 21 Dec 2024 21:18:45 +0000 (13:18 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 21 Dec 2024 21:18:45 +0000 (13:18 -0800)
rust/pspp/src/cooked.rs
rust/pspp/src/dictionary.rs
rust/pspp/src/raw.rs

index 4d46dc4ead6b2a3e943fdef5d2b72f7d061eecc9..f5cd59af93e6ec52b2e3a6adaeddf5bf545b9c41 100644 (file)
@@ -17,6 +17,7 @@ use crate::{
 };
 use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
 use encoding_rs::Encoding;
+use indexmap::set::MutableValues;
 use num::Integer;
 use thiserror::Error as ThisError;
 
@@ -522,28 +523,33 @@ pub fn decode(
         assert_eq!(var_index_map.insert(value_index, dict_index), None);
     }
 
+    if let Some(weight_index) = headers.header.weight_index {
+        if let Some(dict_index) = var_index_map.get(&(weight_index as usize - 1)) {
+            let variable = &dictionary.variables[*dict_index];
+            if variable.is_numeric() {
+                dictionary.weight = Some(*dict_index);
+            } else {
+                warn(Error::TBD);
+            }
+        } else {
+            warn(Error::TBD);
+        }
+    }
+
     for record in headers.value_label.drain(..) {
         let mut dict_indexes = Vec::with_capacity(record.dict_indexes.len());
-        let mut continuation_indexes = Vec::new();
         let mut long_string_variables = Vec::new();
         for value_index in record.dict_indexes.iter() {
-            if let Some(dict_index) = var_index_map.get(&(*value_index as usize - 1)) {
-                let variable = &dictionary.variables[*dict_index];
-                if variable.width.is_long_string() {
-                    long_string_variables.push(variable.name.clone());
-                } else {
-                    dict_indexes.push(*dict_index);
-                }
+            let Some(dict_index) = var_index_map.get(&(*value_index as usize - 1)) else {
+                unreachable!()
+            };
+            let variable = &dictionary.variables[*dict_index];
+            if variable.width.is_long_string() {
+                long_string_variables.push(variable.name.clone());
             } else {
-                continuation_indexes.push(*value_index);
+                dict_indexes.push(*dict_index);
             }
         }
-        if !continuation_indexes.is_empty() {
-            warn(Error::LongStringContinuationIndexes {
-                offset: record.offsets.start,
-                indexes: continuation_indexes,
-            });
-        }
         if !long_string_variables.is_empty() {
             warn(Error::InvalidLongStringValueLabels {
                 offsets: record.offsets.clone(),
index e8dcec2f2051b7ddfdb995441bfd1c67c5eec477..5da085b2e5e0747087ffb6cec0daf3043cdb2b79 100644 (file)
@@ -105,6 +105,14 @@ impl VarWidth {
             VarWidth::String(width) => Some(*width as usize),
         }
     }
+
+    pub fn is_numeric(&self) -> bool {
+        *self == Self::Numeric
+    }
+
+    pub fn is_string(&self) -> bool {
+        !self.is_numeric()
+    }
 }
 
 impl From<VarWidth> for VarType {
@@ -235,6 +243,17 @@ impl Dictionary {
         }
     }
 
+    pub fn weight_var(&self) -> Option<&Variable> {
+        self.weight.map(|index| &self.variables[index].0)
+    }
+
+    pub fn split_vars(&self) -> Vec<&Variable> {
+        self.split_file
+            .iter()
+            .map(|index| &self.variables[*index].0)
+            .collect()
+    }
+
     pub fn add_var(&mut self, variable: Variable) -> Result<usize, DuplicateVariableName> {
         let (index, inserted) = self.variables.insert_full(ByIdentifier::new(variable));
         if inserted {
@@ -453,6 +472,14 @@ impl Variable {
             attributes: HashSet::new(),
         }
     }
+
+    pub fn is_numeric(&self) -> bool {
+        self.width.is_numeric()
+    }
+
+    pub fn is_string(&self) -> bool {
+        self.width.is_string()
+    }
 }
 
 impl HasIdentifier for Variable {
index 5b2a3eb55f898baa47263e43e9326a11885493dc..b8fc79e7cf249e2e9b90f6ce7a92e7e967766023 100644 (file)
@@ -129,7 +129,7 @@ pub enum Warning {
         wrong_types: Vec<u32>,
     },
 
-    #[error("At offset {offset:#x}, one or more variable indexes for value labels were not in the valid range [1,{max}]: {invalid:?}")]
+    #[error("At offset {offset:#x}, one or more variable indexes for value labels were not in the valid range [1,{max}] or referred to string continuations: {invalid:?}")]
     InvalidVarIndexes {
         offset: u64,
         max: usize,
@@ -268,7 +268,7 @@ impl Record {
     fn read<R>(
         reader: &mut R,
         endian: Endian,
-        var_types: &[VarType],
+        var_types: &VarTypes,
         warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error>
     where
@@ -279,7 +279,7 @@ impl Record {
             2 => Ok(Some(VariableRecord::read(reader, endian)?)),
             3 => Ok(ValueLabelRecord::read(reader, endian, var_types, warn)?),
             6 => Ok(Some(DocumentRecord::read(reader, endian)?)),
-            7 => Extension::read(reader, endian, var_types.len(), warn),
+            7 => Extension::read(reader, endian, var_types.n_values(), warn),
             999 => Ok(Some(Record::EndOfHeaders(
                 endian.parse(read_bytes(reader)?),
             ))),
@@ -730,12 +730,12 @@ impl RawValue {
 
     fn read_case<R: Read + Seek>(
         reader: &mut R,
-        var_types: &[VarType],
+        var_types: &VarTypes,
         endian: Endian,
     ) -> Result<Option<Vec<Self>>, Error> {
         let case_start = reader.stream_position()?;
-        let mut values = Vec::with_capacity(var_types.len());
-        for (i, &var_type) in var_types.iter().enumerate() {
+        let mut values = Vec::with_capacity(var_types.n_values());
+        for (i, (var_type, _)) in var_types.types.iter().enumerate() {
             let Some(raw) = try_read_bytes(reader)? else {
                 if i == 0 {
                     return Ok(None);
@@ -744,25 +744,25 @@ impl RawValue {
                     return Err(Error::EofInCase {
                         offset,
                         case_ofs: offset - case_start,
-                        case_len: var_types.len() * 8,
+                        case_len: var_types.n_values() * 8,
                     });
                 }
             };
-            values.push(Value::from_raw(&UntypedValue(raw), var_type, endian));
+            values.push(Value::from_raw(&UntypedValue(raw), *var_type, endian));
         }
         Ok(Some(values))
     }
 
     fn read_compressed_case<R: Read + Seek>(
         reader: &mut R,
-        var_types: &[VarType],
+        var_types: &VarTypes,
         codes: &mut VecDeque<u8>,
         endian: Endian,
         bias: f64,
     ) -> Result<Option<Vec<Self>>, Error> {
         let case_start = reader.stream_position()?;
-        let mut values = Vec::with_capacity(var_types.len());
-        for (i, &var_type) in var_types.iter().enumerate() {
+        let mut values = Vec::with_capacity(var_types.n_values());
+        for (i, (var_type, _)) in var_types.types.iter().enumerate() {
             let value = loop {
                 let Some(code) = codes.pop_front() else {
                     let Some(new_codes): Option<[u8; 8]> = try_read_bytes(reader)? else {
@@ -781,7 +781,7 @@ impl RawValue {
                 };
                 match code {
                     0 => (),
-                    1..=251 => match var_type {
+                    1..=251 => match *var_type {
                         VarType::Numeric => break Self::Number(Some(code as f64 - bias)),
                         VarType::String => {
                             break Self::String(RawStr(endian.to_bytes(code as f64 - bias)))
@@ -799,9 +799,9 @@ impl RawValue {
                         }
                     }
                     253 => {
-                        break Self::from_raw(&UntypedValue(read_bytes(reader)?), var_type, endian)
+                        break Self::from_raw(&UntypedValue(read_bytes(reader)?), *var_type, endian)
                     }
-                    254 => match var_type {
+                    254 => match *var_type {
                         VarType::String => break Self::String(RawStr(*b"        ")), // XXX EBCDIC
                         VarType::Numeric => {
                             return Err(Error::CompressedStringExpected {
@@ -810,7 +810,7 @@ impl RawValue {
                             })
                         }
                     },
-                    255 => match var_type {
+                    255 => match *var_type {
                         VarType::Numeric => break Self::Number(None),
                         VarType::String => {
                             return Err(Error::CompressedNumberExpected {
@@ -898,7 +898,7 @@ where
     warn: Box<dyn Fn(Warning)>,
 
     header: HeaderRecord<RawString>,
-    var_types: Vec<VarType>,
+    var_types: VarTypes,
 
     state: ReaderState,
 }
@@ -916,7 +916,7 @@ where
             reader: Some(reader),
             warn: Box::new(warn),
             header,
-            var_types: Vec::new(),
+            var_types: VarTypes::new(),
             state: ReaderState::Start,
         })
     }
@@ -939,7 +939,7 @@ where
                     match Record::read(
                         self.reader.as_mut().unwrap(),
                         self.header.endian,
-                        self.var_types.as_slice(),
+                        &self.var_types,
                         &self.warn,
                     ) {
                         Ok(Some(record)) => break record,
@@ -948,13 +948,7 @@ where
                     }
                 };
                 match record {
-                    Record::Variable(VariableRecord { width, .. }) => {
-                        self.var_types.push(if width == 0 {
-                            VarType::Numeric
-                        } else {
-                            VarType::String
-                        });
-                    }
+                    Record::Variable(VariableRecord { width, .. }) => self.var_types.push(width),
                     Record::EndOfHeaders(_) => {
                         self.state = if let Some(Compression::ZLib) = self.header.compression {
                             ReaderState::ZlibHeader
@@ -1019,7 +1013,7 @@ impl<T> ReadSeek for T where T: Read + Seek {}
 
 pub struct Cases {
     reader: Box<dyn ReadSeek>,
-    var_types: Vec<VarType>,
+    var_types: VarTypes,
     compression: Option<Compression>,
     bias: f64,
     endian: Endian,
@@ -1034,7 +1028,7 @@ impl Debug for Cases {
 }
 
 impl Cases {
-    fn new<R>(reader: R, var_types: Vec<VarType>, header: &HeaderRecord<RawString>) -> Self
+    fn new<R>(reader: R, var_types: VarTypes, header: &HeaderRecord<RawString>) -> Self
     where
         R: Read + Seek + 'static,
     {
@@ -1502,7 +1496,7 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
     fn read<R: Read + Seek>(
         r: &mut R,
         endian: Endian,
-        var_types: &[VarType],
+        var_types: &VarTypes,
         warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error> {
         let label_offset = r.stream_position()?;
@@ -1553,10 +1547,9 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
         let index_offset = r.stream_position()?;
         let mut dict_indexes = Vec::with_capacity(n as usize);
         let mut invalid_indexes = Vec::new();
-        let valid_range = 1..=var_types.len();
         for _ in 0..n {
             let index: u32 = endian.parse(read_bytes(r)?);
-            if valid_range.contains(&(index as usize)) {
+            if var_types.is_valid_index(index as usize) {
                 dict_indexes.push(index);
             } else {
                 invalid_indexes.push(index);
@@ -1565,7 +1558,7 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
         if !invalid_indexes.is_empty() {
             warn(Warning::InvalidVarIndexes {
                 offset: index_offset,
-                max: var_types.len(),
+                max: var_types.n_values(),
                 invalid: invalid_indexes,
             });
         }
@@ -1573,10 +1566,10 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
         let Some(&first_index) = dict_indexes.first() else {
             return Ok(None);
         };
-        let var_type = var_types[first_index as usize - 1];
+        let var_type = var_types.types[first_index as usize - 1].0;
         let mut wrong_type_indexes = Vec::new();
         dict_indexes.retain(|&index| {
-            if var_types[index as usize - 1] != var_type {
+            if var_types.types[index as usize - 1].0 != var_type {
                 wrong_type_indexes.push(index);
                 false
             } else {
@@ -2887,3 +2880,44 @@ impl LongStringValueLabelRecord<RawString, RawString> {
         LongStringValueLabelRecord(labels)
     }
 }
+
+#[derive(Default)]
+pub struct VarTypes {
+    pub types: Vec<(VarType, usize)>,
+}
+
+impl VarTypes {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn push(&mut self, width: i32) {
+        let var_type = match width {
+            -1 => return,
+            0 => VarType::Numeric,
+            1..=255 => VarType::String,
+            _ => unreachable!(),
+        };
+        let n_values = (width as usize).div_ceil(8).max(1);
+        for i in 0..n_values {
+            self.types.push((var_type, i));
+        }
+    }
+
+    pub fn n_values(&self) -> usize {
+        self.types.len()
+    }
+
+    pub fn is_valid_index(&self, index: usize) -> bool {
+        self.var_type_at(index).is_some()
+    }
+
+    pub fn var_type_at(&self, index: usize) -> Option<VarType> {
+        if index >= 1 && index <= self.types.len() {
+            if let (var_type, 0) = self.types[index - 1] {
+                return Some(var_type);
+            }
+        }
+        None
+    }
+}