start experiment with Encoding, RawStringTrait, MutRawString
authorBen Pfaff <blp@cs.stanford.edu>
Thu, 31 Jul 2025 21:01:36 +0000 (14:01 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Thu, 31 Jul 2025 21:01:36 +0000 (14:01 -0700)
rust/pspp/src/data.rs
rust/pspp/src/data/encoded.rs
rust/pspp/src/dictionary.rs
rust/pspp/src/sys/cooked.rs
rust/pspp/src/sys/test.rs
rust/pspp/src/sys/write.rs

index 3de82de89fd5c4e70c6991da4176f29bf312aaef..294023fdebf2efdd0e2cd906f5b02c5f268fbe4d 100644 (file)
@@ -25,7 +25,7 @@
 //! [Dictionary]: crate::dictionary::Dictionary
 
 // Warn about missing docs, but not for items declared with `#[cfg(test)]`.
-#![cfg_attr(not(test), warn(missing_docs))]
+//#![cfg_attr(not(test), warn(missing_docs))]
 
 use std::{
     borrow::{Borrow, BorrowMut, Cow},
@@ -35,7 +35,7 @@ use std::{
     str::from_utf8,
 };
 
-use encoding_rs::{mem::decode_latin1, Encoding};
+use encoding_rs::{mem::decode_latin1, Encoding, UTF_8};
 use itertools::Itertools;
 use ordered_float::OrderedFloat;
 use serde::{
@@ -48,6 +48,89 @@ use crate::{
     format::DisplayPlain,
 };
 
+pub trait RawStringTrait: Debug + PartialEq + Eq + PartialOrd + Ord {
+    fn raw_string_bytes(&self) -> &[u8];
+
+    /// Compares this string and `other` for equality, ignoring trailing ASCII
+    /// spaces in either string for the purpose of comparison.  (This is
+    /// acceptable because we assume that the encoding is ASCII-compatible.)
+    ///
+    /// This compares the bytes of the strings, disregarding their encodings (if
+    /// known).
+    fn eq_ignore_trailing_spaces<R>(&self, other: &impl RawStringTrait) -> bool {
+        self.raw_string_bytes()
+            .iter()
+            .copied()
+            .zip_longest(other.raw_string_bytes().iter().copied())
+            .all(|elem| {
+                let (left, right) = elem.or(b' ', b' ');
+                left == right
+            })
+    }
+
+    /// Returns true if this raw string can be resized to `len` bytes without
+    /// dropping non-space characters.
+    fn is_resizable(&self, new_len: usize) -> bool {
+        new_len >= self.len()
+            || self.raw_string_bytes()[new_len..]
+                .iter()
+                .copied()
+                .all(|b| b == b' ')
+    }
+
+    fn is_empty(&self) -> bool {
+        self.raw_string_bytes().is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.raw_string_bytes().len()
+    }
+}
+
+pub trait MutRawString: RawStringTrait {
+    fn resize(&mut self, new_len: usize) -> Result<(), ()>;
+    fn trim_end(&mut self);
+}
+
+impl RawStringTrait for str {
+    fn raw_string_bytes(&self) -> &[u8] {
+        self.as_bytes()
+    }
+}
+
+impl RawStringTrait for String {
+    fn raw_string_bytes(&self) -> &[u8] {
+        self.as_bytes()
+    }
+}
+
+impl RawStringTrait for Vec<u8> {
+    fn raw_string_bytes(&self) -> &[u8] {
+        self.as_slice()
+    }
+}
+
+impl MutRawString for Vec<u8> {
+    fn resize(&mut self, new_len: usize) -> Result<(), ()> {
+        match new_len.cmp(&self.len()) {
+            Ordering::Less => {
+                if !self[new_len..].iter().all(|b| *b == b' ') {
+                    return Err(());
+                }
+                self.truncate(new_len);
+            }
+            Ordering::Equal => (),
+            Ordering::Greater => self.extend((self.len()..new_len).map(|_| b' ')),
+        }
+        Ok(())
+    }
+
+    /// Removes any trailing ASCII spaces.
+    fn trim_end(&mut self) {
+        while self.pop_if(|c| *c == b' ').is_some() {}
+    }
+}
+
 /// A string in an unspecified character encoding.
 ///
 /// `RawString` is usually associated with a [Variable], in the variable's
@@ -305,6 +388,12 @@ pub enum Datum<B> {
     ),
 }
 
+impl Datum<OwnedEncodedString> {
+    pub fn new_utf8(s: impl Into<String>) -> Self {
+        Datum::String(OwnedRawString::from(s.into().into_bytes()).with_encoding(UTF_8))
+    }
+}
+
 impl<B> Debug for Datum<B>
 where
     B: Debug,
@@ -613,6 +702,12 @@ where
     }
 }
 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum ResizeError {
+    MixedTypes,
+    TooWide,
+}
+
 impl<B> Datum<B>
 where
     B: BorrowMut<OwnedRawString>,
@@ -634,7 +729,7 @@ where
     /// Resizes this datum to the given `width`.  Returns an error, without
     /// modifying the datum, if [is_resizable](Self::is_resizable) would return
     /// false.
-    pub fn resize(&mut self, width: VarWidth) -> Result<(), ()> {
+    pub fn resize(&mut self, width: VarWidth) -> Result<(), ResizeError> {
         match (self, width) {
             (Self::Number(_), VarWidth::Numeric) => Ok(()),
             (Self::String(s), VarWidth::String(new_width)) => {
@@ -643,10 +738,10 @@ where
                     s.resize(new_width as usize);
                     Ok(())
                 } else {
-                    Err(())
+                    Err(ResizeError::TooWide)
                 }
             }
-            _ => Err(()),
+            _ => Err(ResizeError::MixedTypes),
         }
     }
 }
index e5d2d6289743323bbc0821b6facac7949f478b0b..8e32c862d8d4f96664aba0fb9bb978053a3c2ed0 100644 (file)
 use std::{
     borrow::{Borrow, BorrowMut, Cow},
     cmp::Ordering,
-    fmt::Display,
+    fmt::{Debug, Display},
 };
 
 use encoding_rs::{Encoding, UTF_8};
 use serde::Serialize;
 
-use crate::data::{BorrowedRawString, OwnedRawString, Quoted, RawString};
+use crate::data::{BorrowedRawString, OwnedRawString, Quoted, RawString, RawStringTrait};
+
+pub trait Encoded {
+    fn encoding(&self) -> &'static Encoding;
+}
+
+impl Encoded for str {
+    fn encoding(&self) -> &'static Encoding {
+        UTF_8
+    }
+}
+
+impl Encoded for String {
+    fn encoding(&self) -> &'static Encoding {
+        UTF_8
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct WithEncoding<T> {
+    pub inner: T,
+    pub encoding: &'static Encoding,
+}
+
+impl<T> WithEncoding<T> {
+    pub fn new(inner: T, encoding: &'static Encoding) -> Self {
+        Self { inner, encoding }
+    }
+
+    pub fn into_inner(self) -> T {
+        self.inner
+    }
+}
+
+impl<T> PartialOrd for WithEncoding<T>
+where
+    T: PartialOrd,
+{
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        self.inner.partial_cmp(&other.inner)
+    }
+}
+
+impl<T> Ord for WithEncoding<T>
+where
+    T: Ord,
+{
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.inner.cmp(&other.inner)
+    }
+}
+
+pub trait EncodedStringTrait: Encoded + RawStringTrait + Display + Debug {
+    fn as_str(&self) -> Cow<'_, str>;
+    fn to_encoding(&self, encoding: &'static Encoding) -> Cow<[u8]>;
+}
+
+impl<'a> EncodedStringTrait for str {
+    fn as_str(&self) -> Cow<'_, str> {
+        Cow::from(self)
+    }
+
+    fn to_encoding(&self, encoding: &'static Encoding) -> Cow<[u8]> {
+        encoding.encode(self).0
+    }
+}
+
+impl<T> RawStringTrait for WithEncoding<T>
+where
+    T: RawStringTrait,
+{
+    fn raw_string_bytes(&self) -> &[u8] {
+        self.inner.raw_string_bytes()
+    }
+}
+
+impl<T> EncodedStringTrait for WithEncoding<T>
+where
+    T: RawStringTrait,
+{
+    /// Returns this string recoded in UTF-8.  Invalid characters will be
+    /// replaced by [REPLACEMENT_CHARACTER].
+    ///
+    /// [REPLACEMENT_CHARACTER]: std::char::REPLACEMENT_CHARACTER
+    fn as_str(&self) -> Cow<'_, str> {
+        self.encoding
+            .decode_without_bom_handling(self.raw_string_bytes())
+            .0
+    }
+
+    /// Returns this string recoded in `encoding`.  Invalid characters will be
+    /// replaced by [REPLACEMENT_CHARACTER].
+    ///
+    /// [REPLACEMENT_CHARACTER]: std::char::REPLACEMENT_CHARACTER
+    fn to_encoding(&self, encoding: &'static Encoding) -> Cow<[u8]> {
+        let utf8 = self.as_str();
+        match encoding.encode(&utf8).0 {
+            Cow::Borrowed(_) => {
+                // Recoding into UTF-8 and then back did not change anything.
+                Cow::from(self.raw_string_bytes())
+            }
+            Cow::Owned(owned) => Cow::Owned(owned),
+        }
+    }
+}
+
+impl<T> Encoded for WithEncoding<T> {
+    fn encoding(&self) -> &'static Encoding {
+        self.encoding
+    }
+}
+
+impl<T> Display for WithEncoding<T>
+where
+    T: RawStringTrait,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.as_str())
+    }
+}
 
 pub type OwnedEncodedString = EncodedString<OwnedRawString>;
 pub type BorrowedEncodedString<'a> = EncodedString<&'a BorrowedRawString>;
@@ -24,6 +143,12 @@ pub struct EncodedString<R> {
     pub encoding: &'static Encoding,
 }
 
+impl<R> Encoded for EncodedString<R> {
+    fn encoding(&self) -> &'static Encoding {
+        self.encoding
+    }
+}
+
 impl<R> EncodedString<R>
 where
     R: Borrow<BorrowedRawString>,
index 0198c3bb7e9b3c392afbe4a8349457211c458e01..7cd40e86512815463c2a0f5d1015916be080fc37 100644 (file)
@@ -23,7 +23,7 @@ use std::{
     collections::{btree_set, BTreeMap, BTreeSet, HashMap, HashSet},
     fmt::{Debug, Display, Formatter, Result as FmtResult},
     hash::{DefaultHasher, Hash, Hasher},
-    ops::{Bound, Index, Not, RangeBounds, RangeInclusive},
+    ops::{Bound, Deref, Index, Not, RangeBounds, RangeInclusive},
     str::FromStr,
 };
 
@@ -40,7 +40,7 @@ use thiserror::Error as ThisError;
 use unicase::UniCase;
 
 use crate::{
-    data::{AsEncodedString, Datum, OwnedEncodedString, OwnedRawString},
+    data::{AsEncodedString, Datum, OwnedEncodedString, OwnedRawString, ResizeError},
     format::{DisplayPlain, Format},
     identifier::{ByIdentifier, HasIdentifier, Identifier},
     output::pivot::{
@@ -1291,7 +1291,7 @@ pub struct Variable {
     /// `None`).
     ///
     /// Both kinds of missing values are excluded from most analyses.
-    pub missing_values: MissingValues,
+    missing_values: MissingValues,
 
     /// Output format used in most contexts.
     pub print_format: Format,
@@ -1401,6 +1401,17 @@ impl Variable {
 
         self.width = width;
     }
+
+    pub fn missing_values(&self) -> &MissingValues {
+        &self.missing_values
+    }
+
+    pub fn missing_values_mut(&mut self) -> MissingValuesMut<'_> {
+        MissingValuesMut {
+            inner: &mut self.missing_values,
+            width: self.width,
+        }
+    }
 }
 
 impl HasIdentifier for Variable {
@@ -1954,6 +1965,69 @@ impl Hash for ValueLabels {
     }
 }
 
+pub struct MissingValuesMut<'a> {
+    inner: &'a mut MissingValues,
+    width: VarWidth,
+}
+
+impl<'a> Deref for MissingValuesMut<'a> {
+    type Target = MissingValues;
+
+    fn deref(&self) -> &Self::Target {
+        self.inner
+    }
+}
+
+impl<'a> MissingValuesMut<'a> {
+    pub fn replace(&mut self, mut new: MissingValues) -> Result<(), MissingValuesError> {
+        new.resize(self.width)?;
+        *self.inner = new;
+        Ok(())
+    }
+
+    pub fn add_value(
+        &mut self,
+        mut value: Datum<OwnedEncodedString>,
+    ) -> Result<(), MissingValuesError> {
+        if self.inner.values.len() > 2
+            || (self.inner.range().is_some() && self.inner.values.len() > 1)
+        {
+            Err(MissingValuesError::TooMany)
+        } else if value.var_type() != VarType::from(self.width) {
+            Err(MissingValuesError::MixedTypes)
+        } else if value.resize(self.width).is_err() {
+            Err(MissingValuesError::TooWide)
+        } else {
+            value.trim_end();
+            self.inner.values.push(value);
+            Ok(())
+        }
+    }
+
+    pub fn add_values(
+        &mut self,
+        values: impl IntoIterator<Item = Datum<OwnedEncodedString>>,
+    ) -> Result<(), MissingValuesError> {
+        let n = self.inner.values.len();
+        for value in values {
+            self.add_value(value)
+                .inspect_err(|_| self.inner.values.truncate(n))?;
+        }
+        Ok(())
+    }
+
+    pub fn add_range(&mut self, range: MissingValueRange) -> Result<(), MissingValuesError> {
+        if self.inner.range.is_some() || self.inner.values().len() > 1 {
+            Err(MissingValuesError::TooMany)
+        } else if self.width != VarWidth::Numeric {
+            Err(MissingValuesError::MixedTypes)
+        } else {
+            self.inner.range = Some(range);
+            Ok(())
+        }
+    }
+}
+
 #[derive(Clone, Default, Serialize)]
 pub struct MissingValues {
     /// Individual missing values, up to 3 of them.
@@ -1999,6 +2073,15 @@ pub enum MissingValuesError {
     MixedTypes,
 }
 
+impl From<ResizeError> for MissingValuesError {
+    fn from(value: ResizeError) -> Self {
+        match value {
+            ResizeError::MixedTypes => MissingValuesError::MixedTypes,
+            ResizeError::TooWide => MissingValuesError::TooWide,
+        }
+    }
+}
+
 impl MissingValues {
     pub fn clear(&mut self) {
         *self = Self::default();
@@ -2073,10 +2156,11 @@ impl MissingValues {
         }
     }
 
-    pub fn resize(&mut self, width: VarWidth) -> Result<(), ()> {
-        fn inner(this: &mut MissingValues, width: VarWidth) -> Result<(), ()> {
+    pub fn resize(&mut self, width: VarWidth) -> Result<(), MissingValuesError> {
+        fn inner(this: &mut MissingValues, width: VarWidth) -> Result<(), MissingValuesError> {
             for datum in &mut this.values {
                 datum.resize(width)?;
+                datum.trim_end();
             }
             if let Some(range) = &mut this.range {
                 range.resize(width)?;
@@ -2126,11 +2210,11 @@ impl MissingValueRange {
         }
     }
 
-    pub fn resize(&self, width: VarWidth) -> Result<(), ()> {
+    pub fn resize(&self, width: VarWidth) -> Result<(), MissingValuesError> {
         if width.is_numeric() {
             Ok(())
         } else {
-            Err(())
+            Err(MissingValuesError::MixedTypes)
         }
     }
 }
index 448f2471d570d81f0da47e5a53f5896c1b725025..82edb76f92dad676690fe24c7b6baafaa68011ba 100644 (file)
@@ -913,7 +913,9 @@ impl Records {
 
             variable.label = input.label.clone();
 
-            variable.missing_values = input.missing_values.decode(encoding).unwrap();
+            variable
+                .missing_values_mut()
+                .replace(input.missing_values.decode(encoding).unwrap());
 
             variable.print_format = decode_format(
                 input.print_format,
@@ -1262,7 +1264,10 @@ impl Records {
                 })
                 .collect::<Vec<_>>();
             match MissingValues::new(values, None) {
-                Ok(missing_values) => variable.missing_values = missing_values,
+                Ok(missing_values) => variable
+                    .missing_values_mut()
+                    .replace(missing_values)
+                    .unwrap(),
                 Err(MissingValuesError::TooWide) => {
                     warn(Error::MissingValuesTooWide(record.var_name.clone()))
                 }
index 7fa06d650b9a1a3e54f62e43bccf508a4038d845..24463aa6327680ce8ad658b945ce43f34e902ede 100644 (file)
@@ -24,12 +24,11 @@ use std::{
 
 use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
 use encoding_rs::UTF_8;
-use hexplay::HexView;
 
 use crate::{
     crypto::EncryptedFile,
-    data::{BorrowedDatum, Datum},
-    dictionary::{self, Dictionary, VarWidth, Variable},
+    data::{BorrowedDatum, Datum, OwnedDatum, RawString},
+    dictionary::{Dictionary, VarWidth, Variable},
     endian::Endian,
     identifier::Identifier,
     output::{
@@ -573,6 +572,25 @@ fn encrypted_file_without_password() {
     ));
 }
 
+impl WriteOptions {
+    /// Returns a [WriteOptions] with the given `compression` and the other
+    /// members set to fixed values so that running at different times or with
+    /// different crate names or versions won't change what's written to the
+    /// file.
+    fn reproducible(compression: Option<Compression>) -> Self {
+        WriteOptions::new()
+            .with_compression(compression)
+            .with_timestamp(NaiveDateTime::new(
+                NaiveDate::from_ymd_opt(2025, 7, 30).unwrap(),
+                NaiveTime::from_hms_opt(15, 7, 55).unwrap(),
+            ))
+            .with_product_name(Cow::from("PSPP TEST DATA FILE"))
+            .with_product_version(ProductVersion(1, 2, 3))
+    }
+}
+
+/// Tests the most basic kind of writing a system file, just writing a few
+/// numeric variables and cases.
 #[test]
 fn write_numeric() {
     for (compression, compression_string) in [
@@ -587,14 +605,7 @@ fn write_numeric() {
                 .add_var(Variable::new(name, VarWidth::Numeric, UTF_8))
                 .unwrap();
         }
-        let mut cases = WriteOptions::new()
-            .with_compression(compression)
-            .with_timestamp(NaiveDateTime::new(
-                NaiveDate::from_ymd_opt(2025, 7, 30).unwrap(),
-                NaiveTime::from_hms_opt(15, 7, 55).unwrap(),
-            ))
-            .with_product_name(Cow::from("PSPP TEST DATA FILE"))
-            .with_product_version(ProductVersion(1, 2, 3))
+        let mut cases = WriteOptions::reproducible(compression)
             .write_writer(&dictionary, Cursor::new(Vec::new()))
             .unwrap();
         for case in [
@@ -623,6 +634,62 @@ fn write_numeric() {
     }
 }
 
+/// Tests writing long string value labels and missing values.
+#[test]
+fn write_long_string_value_labels() {
+    for (compression, compression_string) in [
+        (None, "uncompressed"),
+        (Some(Compression::Simple), "simple"),
+        (Some(Compression::ZLib), "zlib"),
+    ] {
+        let mut dictionary = Dictionary::new(UTF_8);
+        let mut s1 = Variable::new(Identifier::new("s1").unwrap(), VarWidth::String(9), UTF_8);
+        s1.value_labels.insert(
+            OwnedDatum::String(RawString(String::from("abc      ").into_bytes())),
+            String::from("First value label"),
+        );
+        s1.value_labels.insert(
+            OwnedDatum::String(RawString(String::from("abcdefgh ").into_bytes())),
+            String::from("Second value label"),
+        );
+        s1.value_labels.insert(
+            OwnedDatum::String(RawString(String::from("abcdefghi").into_bytes())),
+            String::from("Third value label"),
+        );
+        s1.missing_values_mut()
+            .add_values([Datum::new_utf8("0")])
+            .unwrap();
+        dictionary.add_var(s1).unwrap();
+        /*
+        let mut cases = WriteOptions::reproducible(compression)
+            .write_writer(&dictionary, Cursor::new(Vec::new()))
+            .unwrap();
+        for case in [
+            [1, 1, 1, 2],
+            [1, 1, 2, 30],
+            [1, 2, 1, 8],
+            [1, 2, 2, 20],
+            [2, 1, 1, 2],
+            [2, 1, 2, 22],
+            [2, 2, 1, 1],
+            [2, 2, 2, 3],
+        ] {
+            cases
+                .write_case(
+                    case.into_iter()
+                        .map(|number| BorrowedDatum::Number(Some(number as f64))),
+                )
+                .unwrap();
+        }
+        let sysfile = cases.finish().unwrap().unwrap().into_inner();
+        let expected_filename = PathBuf::from(&format!(
+            "src/sys/testdata/write-numeric-{compression_string}.expected"
+        ));
+        let expected = String::from_utf8(std::fs::read(&expected_filename).unwrap()).unwrap();
+        test_sysfile(Cursor::new(sysfile), &expected, &expected_filename);*/
+    }
+}
+
 fn test_raw_sysfile(name: &str) {
     let input_filename = Path::new("src/sys/testdata")
         .join(name)
index cdbb5e6028662d6293528d188a0148a8cdfcf278..a4da7e680036eb4f5cdd13f2f88cc9309b2258b4 100644 (file)
@@ -276,8 +276,8 @@ where
                 width: seg0_width.as_string_width().unwrap_or(0) as i32,
                 has_variable_label: variable.label.is_some() as u32,
                 missing_value_code: if variable.width.is_long_string() {
-                    let n = variable.missing_values.values().len() as i32;
-                    match variable.missing_values.range() {
+                    let n = variable.missing_values().values().len() as i32;
+                    match variable.missing_values().range() {
                         Some(_) => -(n + 2),
                         None => n,
                     }
@@ -300,14 +300,14 @@ where
 
             // Missing values.
             if !variable.width.is_long_string() {
-                if let Some(range) = variable.missing_values.range() {
+                if let Some(range) = variable.missing_values().range() {
                     (
                         range.low().unwrap_or(f64::MIN),
                         range.high().unwrap_or(f64::MAX),
                     )
                         .write_le(self.writer)?;
                 }
-                variable.missing_values.values().write_le(self.writer)?;
+                variable.missing_values().values().write_le(self.writer)?;
             }
             write_variable_continuation_records(&mut self.writer, seg0_width)?;
 
@@ -601,21 +601,22 @@ where
         let mut body = Vec::new();
         let mut cursor = Cursor::new(&mut body);
         for variable in &self.dictionary.variables {
-            if variable.missing_values.is_empty() || !variable.width.is_long_string() {
+            if variable.missing_values().is_empty() || !variable.width.is_long_string() {
                 break;
             }
             let name = self.dictionary.encoding().encode(&variable.name).0;
             (
                 name.len() as u32,
                 &name[..],
-                variable.missing_values.values().len() as u32,
+                variable.missing_values().values().len() as u32,
                 8u32,
             )
                 .write_le(&mut cursor)?;
 
-            for value in variable.missing_values.values() {
-                let value = value.as_string().unwrap();
-                value.as_bytes()[..8].write_le(&mut cursor).unwrap();
+            for value in variable.missing_values().values() {
+                let value = value.as_string().unwrap().as_bytes();
+                let bytes = value.get(..8).unwrap_or(value);
+                Padded::exact(bytes, 8, b' ').write_le(&mut cursor).unwrap();
             }
         }
         self.write_bytes_record(22, &body)