From: Ben Pfaff Date: Thu, 31 Jul 2025 21:01:36 +0000 (-0700) Subject: start experiment with Encoding, RawStringTrait, MutRawString X-Git-Url: https://pintos-os.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a722de3544c0f2e2e6d69054e6fce1f6fefe4cf;p=pspp start experiment with Encoding, RawStringTrait, MutRawString --- diff --git a/rust/pspp/src/data.rs b/rust/pspp/src/data.rs index 3de82de89f..294023fdeb 100644 --- a/rust/pspp/src/data.rs +++ b/rust/pspp/src/data.rs @@ -25,7 +25,7 @@ //! [Dictionary]: crate::dictionary::Dictionary // Warn about missing docs, but not for items declared with `#[cfg(test)]`. -#![cfg_attr(not(test), warn(missing_docs))] +//#![cfg_attr(not(test), warn(missing_docs))] use std::{ borrow::{Borrow, BorrowMut, Cow}, @@ -35,7 +35,7 @@ use std::{ str::from_utf8, }; -use encoding_rs::{mem::decode_latin1, Encoding}; +use encoding_rs::{mem::decode_latin1, Encoding, UTF_8}; use itertools::Itertools; use ordered_float::OrderedFloat; use serde::{ @@ -48,6 +48,89 @@ use crate::{ format::DisplayPlain, }; +pub trait RawStringTrait: Debug + PartialEq + Eq + PartialOrd + Ord { + fn raw_string_bytes(&self) -> &[u8]; + + /// Compares this string and `other` for equality, ignoring trailing ASCII + /// spaces in either string for the purpose of comparison. (This is + /// acceptable because we assume that the encoding is ASCII-compatible.) + /// + /// This compares the bytes of the strings, disregarding their encodings (if + /// known). + fn eq_ignore_trailing_spaces(&self, other: &impl RawStringTrait) -> bool { + self.raw_string_bytes() + .iter() + .copied() + .zip_longest(other.raw_string_bytes().iter().copied()) + .all(|elem| { + let (left, right) = elem.or(b' ', b' '); + left == right + }) + } + + /// Returns true if this raw string can be resized to `len` bytes without + /// dropping non-space characters. + fn is_resizable(&self, new_len: usize) -> bool { + new_len >= self.len() + || self.raw_string_bytes()[new_len..] + .iter() + .copied() + .all(|b| b == b' ') + } + + fn is_empty(&self) -> bool { + self.raw_string_bytes().is_empty() + } + + fn len(&self) -> usize { + self.raw_string_bytes().len() + } +} + +pub trait MutRawString: RawStringTrait { + fn resize(&mut self, new_len: usize) -> Result<(), ()>; + fn trim_end(&mut self); +} + +impl RawStringTrait for str { + fn raw_string_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl RawStringTrait for String { + fn raw_string_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl RawStringTrait for Vec { + fn raw_string_bytes(&self) -> &[u8] { + self.as_slice() + } +} + +impl MutRawString for Vec { + fn resize(&mut self, new_len: usize) -> Result<(), ()> { + match new_len.cmp(&self.len()) { + Ordering::Less => { + if !self[new_len..].iter().all(|b| *b == b' ') { + return Err(()); + } + self.truncate(new_len); + } + Ordering::Equal => (), + Ordering::Greater => self.extend((self.len()..new_len).map(|_| b' ')), + } + Ok(()) + } + + /// Removes any trailing ASCII spaces. + fn trim_end(&mut self) { + while self.pop_if(|c| *c == b' ').is_some() {} + } +} + /// A string in an unspecified character encoding. /// /// `RawString` is usually associated with a [Variable], in the variable's @@ -305,6 +388,12 @@ pub enum Datum { ), } +impl Datum { + pub fn new_utf8(s: impl Into) -> Self { + Datum::String(OwnedRawString::from(s.into().into_bytes()).with_encoding(UTF_8)) + } +} + impl Debug for Datum where B: Debug, @@ -613,6 +702,12 @@ where } } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ResizeError { + MixedTypes, + TooWide, +} + impl Datum where B: BorrowMut, @@ -634,7 +729,7 @@ where /// Resizes this datum to the given `width`. Returns an error, without /// modifying the datum, if [is_resizable](Self::is_resizable) would return /// false. - pub fn resize(&mut self, width: VarWidth) -> Result<(), ()> { + pub fn resize(&mut self, width: VarWidth) -> Result<(), ResizeError> { match (self, width) { (Self::Number(_), VarWidth::Numeric) => Ok(()), (Self::String(s), VarWidth::String(new_width)) => { @@ -643,10 +738,10 @@ where s.resize(new_width as usize); Ok(()) } else { - Err(()) + Err(ResizeError::TooWide) } } - _ => Err(()), + _ => Err(ResizeError::MixedTypes), } } } diff --git a/rust/pspp/src/data/encoded.rs b/rust/pspp/src/data/encoded.rs index e5d2d62897..8e32c862d8 100644 --- a/rust/pspp/src/data/encoded.rs +++ b/rust/pspp/src/data/encoded.rs @@ -1,13 +1,132 @@ use std::{ borrow::{Borrow, BorrowMut, Cow}, cmp::Ordering, - fmt::Display, + fmt::{Debug, Display}, }; use encoding_rs::{Encoding, UTF_8}; use serde::Serialize; -use crate::data::{BorrowedRawString, OwnedRawString, Quoted, RawString}; +use crate::data::{BorrowedRawString, OwnedRawString, Quoted, RawString, RawStringTrait}; + +pub trait Encoded { + fn encoding(&self) -> &'static Encoding; +} + +impl Encoded for str { + fn encoding(&self) -> &'static Encoding { + UTF_8 + } +} + +impl Encoded for String { + fn encoding(&self) -> &'static Encoding { + UTF_8 + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct WithEncoding { + pub inner: T, + pub encoding: &'static Encoding, +} + +impl WithEncoding { + pub fn new(inner: T, encoding: &'static Encoding) -> Self { + Self { inner, encoding } + } + + pub fn into_inner(self) -> T { + self.inner + } +} + +impl PartialOrd for WithEncoding +where + T: PartialOrd, +{ + fn partial_cmp(&self, other: &Self) -> Option { + self.inner.partial_cmp(&other.inner) + } +} + +impl Ord for WithEncoding +where + T: Ord, +{ + fn cmp(&self, other: &Self) -> Ordering { + self.inner.cmp(&other.inner) + } +} + +pub trait EncodedStringTrait: Encoded + RawStringTrait + Display + Debug { + fn as_str(&self) -> Cow<'_, str>; + fn to_encoding(&self, encoding: &'static Encoding) -> Cow<[u8]>; +} + +impl<'a> EncodedStringTrait for str { + fn as_str(&self) -> Cow<'_, str> { + Cow::from(self) + } + + fn to_encoding(&self, encoding: &'static Encoding) -> Cow<[u8]> { + encoding.encode(self).0 + } +} + +impl RawStringTrait for WithEncoding +where + T: RawStringTrait, +{ + fn raw_string_bytes(&self) -> &[u8] { + self.inner.raw_string_bytes() + } +} + +impl EncodedStringTrait for WithEncoding +where + T: RawStringTrait, +{ + /// Returns this string recoded in UTF-8. Invalid characters will be + /// replaced by [REPLACEMENT_CHARACTER]. + /// + /// [REPLACEMENT_CHARACTER]: std::char::REPLACEMENT_CHARACTER + fn as_str(&self) -> Cow<'_, str> { + self.encoding + .decode_without_bom_handling(self.raw_string_bytes()) + .0 + } + + /// Returns this string recoded in `encoding`. Invalid characters will be + /// replaced by [REPLACEMENT_CHARACTER]. + /// + /// [REPLACEMENT_CHARACTER]: std::char::REPLACEMENT_CHARACTER + fn to_encoding(&self, encoding: &'static Encoding) -> Cow<[u8]> { + let utf8 = self.as_str(); + match encoding.encode(&utf8).0 { + Cow::Borrowed(_) => { + // Recoding into UTF-8 and then back did not change anything. + Cow::from(self.raw_string_bytes()) + } + Cow::Owned(owned) => Cow::Owned(owned), + } + } +} + +impl Encoded for WithEncoding { + fn encoding(&self) -> &'static Encoding { + self.encoding + } +} + +impl Display for WithEncoding +where + T: RawStringTrait, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.as_str()) + } +} pub type OwnedEncodedString = EncodedString; pub type BorrowedEncodedString<'a> = EncodedString<&'a BorrowedRawString>; @@ -24,6 +143,12 @@ pub struct EncodedString { pub encoding: &'static Encoding, } +impl Encoded for EncodedString { + fn encoding(&self) -> &'static Encoding { + self.encoding + } +} + impl EncodedString where R: Borrow, diff --git a/rust/pspp/src/dictionary.rs b/rust/pspp/src/dictionary.rs index 0198c3bb7e..7cd40e8651 100644 --- a/rust/pspp/src/dictionary.rs +++ b/rust/pspp/src/dictionary.rs @@ -23,7 +23,7 @@ use std::{ collections::{btree_set, BTreeMap, BTreeSet, HashMap, HashSet}, fmt::{Debug, Display, Formatter, Result as FmtResult}, hash::{DefaultHasher, Hash, Hasher}, - ops::{Bound, Index, Not, RangeBounds, RangeInclusive}, + ops::{Bound, Deref, Index, Not, RangeBounds, RangeInclusive}, str::FromStr, }; @@ -40,7 +40,7 @@ use thiserror::Error as ThisError; use unicase::UniCase; use crate::{ - data::{AsEncodedString, Datum, OwnedEncodedString, OwnedRawString}, + data::{AsEncodedString, Datum, OwnedEncodedString, OwnedRawString, ResizeError}, format::{DisplayPlain, Format}, identifier::{ByIdentifier, HasIdentifier, Identifier}, output::pivot::{ @@ -1291,7 +1291,7 @@ pub struct Variable { /// `None`). /// /// Both kinds of missing values are excluded from most analyses. - pub missing_values: MissingValues, + missing_values: MissingValues, /// Output format used in most contexts. pub print_format: Format, @@ -1401,6 +1401,17 @@ impl Variable { self.width = width; } + + pub fn missing_values(&self) -> &MissingValues { + &self.missing_values + } + + pub fn missing_values_mut(&mut self) -> MissingValuesMut<'_> { + MissingValuesMut { + inner: &mut self.missing_values, + width: self.width, + } + } } impl HasIdentifier for Variable { @@ -1954,6 +1965,69 @@ impl Hash for ValueLabels { } } +pub struct MissingValuesMut<'a> { + inner: &'a mut MissingValues, + width: VarWidth, +} + +impl<'a> Deref for MissingValuesMut<'a> { + type Target = MissingValues; + + fn deref(&self) -> &Self::Target { + self.inner + } +} + +impl<'a> MissingValuesMut<'a> { + pub fn replace(&mut self, mut new: MissingValues) -> Result<(), MissingValuesError> { + new.resize(self.width)?; + *self.inner = new; + Ok(()) + } + + pub fn add_value( + &mut self, + mut value: Datum, + ) -> Result<(), MissingValuesError> { + if self.inner.values.len() > 2 + || (self.inner.range().is_some() && self.inner.values.len() > 1) + { + Err(MissingValuesError::TooMany) + } else if value.var_type() != VarType::from(self.width) { + Err(MissingValuesError::MixedTypes) + } else if value.resize(self.width).is_err() { + Err(MissingValuesError::TooWide) + } else { + value.trim_end(); + self.inner.values.push(value); + Ok(()) + } + } + + pub fn add_values( + &mut self, + values: impl IntoIterator>, + ) -> Result<(), MissingValuesError> { + let n = self.inner.values.len(); + for value in values { + self.add_value(value) + .inspect_err(|_| self.inner.values.truncate(n))?; + } + Ok(()) + } + + pub fn add_range(&mut self, range: MissingValueRange) -> Result<(), MissingValuesError> { + if self.inner.range.is_some() || self.inner.values().len() > 1 { + Err(MissingValuesError::TooMany) + } else if self.width != VarWidth::Numeric { + Err(MissingValuesError::MixedTypes) + } else { + self.inner.range = Some(range); + Ok(()) + } + } +} + #[derive(Clone, Default, Serialize)] pub struct MissingValues { /// Individual missing values, up to 3 of them. @@ -1999,6 +2073,15 @@ pub enum MissingValuesError { MixedTypes, } +impl From for MissingValuesError { + fn from(value: ResizeError) -> Self { + match value { + ResizeError::MixedTypes => MissingValuesError::MixedTypes, + ResizeError::TooWide => MissingValuesError::TooWide, + } + } +} + impl MissingValues { pub fn clear(&mut self) { *self = Self::default(); @@ -2073,10 +2156,11 @@ impl MissingValues { } } - pub fn resize(&mut self, width: VarWidth) -> Result<(), ()> { - fn inner(this: &mut MissingValues, width: VarWidth) -> Result<(), ()> { + pub fn resize(&mut self, width: VarWidth) -> Result<(), MissingValuesError> { + fn inner(this: &mut MissingValues, width: VarWidth) -> Result<(), MissingValuesError> { for datum in &mut this.values { datum.resize(width)?; + datum.trim_end(); } if let Some(range) = &mut this.range { range.resize(width)?; @@ -2126,11 +2210,11 @@ impl MissingValueRange { } } - pub fn resize(&self, width: VarWidth) -> Result<(), ()> { + pub fn resize(&self, width: VarWidth) -> Result<(), MissingValuesError> { if width.is_numeric() { Ok(()) } else { - Err(()) + Err(MissingValuesError::MixedTypes) } } } diff --git a/rust/pspp/src/sys/cooked.rs b/rust/pspp/src/sys/cooked.rs index 448f2471d5..82edb76f92 100644 --- a/rust/pspp/src/sys/cooked.rs +++ b/rust/pspp/src/sys/cooked.rs @@ -913,7 +913,9 @@ impl Records { variable.label = input.label.clone(); - variable.missing_values = input.missing_values.decode(encoding).unwrap(); + variable + .missing_values_mut() + .replace(input.missing_values.decode(encoding).unwrap()); variable.print_format = decode_format( input.print_format, @@ -1262,7 +1264,10 @@ impl Records { }) .collect::>(); match MissingValues::new(values, None) { - Ok(missing_values) => variable.missing_values = missing_values, + Ok(missing_values) => variable + .missing_values_mut() + .replace(missing_values) + .unwrap(), Err(MissingValuesError::TooWide) => { warn(Error::MissingValuesTooWide(record.var_name.clone())) } diff --git a/rust/pspp/src/sys/test.rs b/rust/pspp/src/sys/test.rs index 7fa06d650b..24463aa632 100644 --- a/rust/pspp/src/sys/test.rs +++ b/rust/pspp/src/sys/test.rs @@ -24,12 +24,11 @@ use std::{ use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use encoding_rs::UTF_8; -use hexplay::HexView; use crate::{ crypto::EncryptedFile, - data::{BorrowedDatum, Datum}, - dictionary::{self, Dictionary, VarWidth, Variable}, + data::{BorrowedDatum, Datum, OwnedDatum, RawString}, + dictionary::{Dictionary, VarWidth, Variable}, endian::Endian, identifier::Identifier, output::{ @@ -573,6 +572,25 @@ fn encrypted_file_without_password() { )); } +impl WriteOptions { + /// Returns a [WriteOptions] with the given `compression` and the other + /// members set to fixed values so that running at different times or with + /// different crate names or versions won't change what's written to the + /// file. + fn reproducible(compression: Option) -> Self { + WriteOptions::new() + .with_compression(compression) + .with_timestamp(NaiveDateTime::new( + NaiveDate::from_ymd_opt(2025, 7, 30).unwrap(), + NaiveTime::from_hms_opt(15, 7, 55).unwrap(), + )) + .with_product_name(Cow::from("PSPP TEST DATA FILE")) + .with_product_version(ProductVersion(1, 2, 3)) + } +} + +/// Tests the most basic kind of writing a system file, just writing a few +/// numeric variables and cases. #[test] fn write_numeric() { for (compression, compression_string) in [ @@ -587,14 +605,7 @@ fn write_numeric() { .add_var(Variable::new(name, VarWidth::Numeric, UTF_8)) .unwrap(); } - let mut cases = WriteOptions::new() - .with_compression(compression) - .with_timestamp(NaiveDateTime::new( - NaiveDate::from_ymd_opt(2025, 7, 30).unwrap(), - NaiveTime::from_hms_opt(15, 7, 55).unwrap(), - )) - .with_product_name(Cow::from("PSPP TEST DATA FILE")) - .with_product_version(ProductVersion(1, 2, 3)) + let mut cases = WriteOptions::reproducible(compression) .write_writer(&dictionary, Cursor::new(Vec::new())) .unwrap(); for case in [ @@ -623,6 +634,62 @@ fn write_numeric() { } } +/// Tests writing long string value labels and missing values. +#[test] +fn write_long_string_value_labels() { + for (compression, compression_string) in [ + (None, "uncompressed"), + (Some(Compression::Simple), "simple"), + (Some(Compression::ZLib), "zlib"), + ] { + let mut dictionary = Dictionary::new(UTF_8); + let mut s1 = Variable::new(Identifier::new("s1").unwrap(), VarWidth::String(9), UTF_8); + s1.value_labels.insert( + OwnedDatum::String(RawString(String::from("abc ").into_bytes())), + String::from("First value label"), + ); + s1.value_labels.insert( + OwnedDatum::String(RawString(String::from("abcdefgh ").into_bytes())), + String::from("Second value label"), + ); + s1.value_labels.insert( + OwnedDatum::String(RawString(String::from("abcdefghi").into_bytes())), + String::from("Third value label"), + ); + s1.missing_values_mut() + .add_values([Datum::new_utf8("0")]) + .unwrap(); + dictionary.add_var(s1).unwrap(); + /* + let mut cases = WriteOptions::reproducible(compression) + .write_writer(&dictionary, Cursor::new(Vec::new())) + .unwrap(); + for case in [ + [1, 1, 1, 2], + [1, 1, 2, 30], + [1, 2, 1, 8], + [1, 2, 2, 20], + [2, 1, 1, 2], + [2, 1, 2, 22], + [2, 2, 1, 1], + [2, 2, 2, 3], + ] { + cases + .write_case( + case.into_iter() + .map(|number| BorrowedDatum::Number(Some(number as f64))), + ) + .unwrap(); + } + let sysfile = cases.finish().unwrap().unwrap().into_inner(); + let expected_filename = PathBuf::from(&format!( + "src/sys/testdata/write-numeric-{compression_string}.expected" + )); + let expected = String::from_utf8(std::fs::read(&expected_filename).unwrap()).unwrap(); + test_sysfile(Cursor::new(sysfile), &expected, &expected_filename);*/ + } +} + fn test_raw_sysfile(name: &str) { let input_filename = Path::new("src/sys/testdata") .join(name) diff --git a/rust/pspp/src/sys/write.rs b/rust/pspp/src/sys/write.rs index cdbb5e6028..a4da7e6800 100644 --- a/rust/pspp/src/sys/write.rs +++ b/rust/pspp/src/sys/write.rs @@ -276,8 +276,8 @@ where width: seg0_width.as_string_width().unwrap_or(0) as i32, has_variable_label: variable.label.is_some() as u32, missing_value_code: if variable.width.is_long_string() { - let n = variable.missing_values.values().len() as i32; - match variable.missing_values.range() { + let n = variable.missing_values().values().len() as i32; + match variable.missing_values().range() { Some(_) => -(n + 2), None => n, } @@ -300,14 +300,14 @@ where // Missing values. if !variable.width.is_long_string() { - if let Some(range) = variable.missing_values.range() { + if let Some(range) = variable.missing_values().range() { ( range.low().unwrap_or(f64::MIN), range.high().unwrap_or(f64::MAX), ) .write_le(self.writer)?; } - variable.missing_values.values().write_le(self.writer)?; + variable.missing_values().values().write_le(self.writer)?; } write_variable_continuation_records(&mut self.writer, seg0_width)?; @@ -601,21 +601,22 @@ where let mut body = Vec::new(); let mut cursor = Cursor::new(&mut body); for variable in &self.dictionary.variables { - if variable.missing_values.is_empty() || !variable.width.is_long_string() { + if variable.missing_values().is_empty() || !variable.width.is_long_string() { break; } let name = self.dictionary.encoding().encode(&variable.name).0; ( name.len() as u32, &name[..], - variable.missing_values.values().len() as u32, + variable.missing_values().values().len() as u32, 8u32, ) .write_le(&mut cursor)?; - for value in variable.missing_values.values() { - let value = value.as_string().unwrap(); - value.as_bytes()[..8].write_le(&mut cursor).unwrap(); + for value in variable.missing_values().values() { + let value = value.as_string().unwrap().as_bytes(); + let bytes = value.get(..8).unwrap_or(value); + Padded::exact(bytes, 8, b' ').write_le(&mut cursor).unwrap(); } } self.write_bytes_record(22, &body)