speed up pick_short_names
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 2 Aug 2025 23:30:34 +0000 (16:30 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 2 Aug 2025 23:31:09 +0000 (16:31 -0700)
This made a real difference for tests, even in release mode

rust/pspp/src/dictionary.rs
rust/pspp/src/sys/mod.rs
rust/pspp/src/sys/raw/records.rs
rust/pspp/src/sys/sack.rs
rust/pspp/src/sys/test.rs
rust/pspp/src/sys/write.rs

index 2d206eb8ad9159dc257c5a05f11afcddfdf926e0..d42abfafe2ae7bad6e8e9cacd552bedb4f1b6094 100644 (file)
@@ -680,32 +680,49 @@ impl Dictionary {
     }
 
     pub fn short_names(&self) -> Vec<SmallVec<[Identifier; 1]>> {
-        fn pick_short_name(
-            variable_name: &Identifier,
-            used_names: &mut HashSet<Identifier>,
+        struct PickShortName<'a> {
+            variable_name: &'a Identifier,
+            used_names: &'a mut HashSet<Identifier>,
             encoding: &'static Encoding,
-        ) -> Identifier {
-            for index in 0.. {
-                let name = if index == 0 {
-                    variable_name.shortened(encoding)
-                } else {
-                    variable_name
-                        .with_suffix(
-                            &format!("_{}", Display26Adic::new_uppercase(index)),
-                            encoding,
-                            8,
-                        )
-                        .or_else(|_| {
-                            Identifier::new(format!("V{}", Display26Adic::new_uppercase(index)))
-                        })
-                        .unwrap()
-                };
-                if !used_names.contains(&name) {
-                    used_names.insert(name.clone());
-                    return name;
+            index: usize,
+        }
+        impl<'a> PickShortName<'a> {
+            fn new(
+                variable_name: &'a Identifier,
+                used_names: &'a mut HashSet<Identifier>,
+                encoding: &'static Encoding,
+            ) -> Self {
+                Self {
+                    variable_name,
+                    used_names,
+                    encoding,
+                    index: 0,
+                }
+            }
+
+            fn next(&mut self) -> Identifier {
+                loop {
+                    let name = if self.index == 0 {
+                        self.variable_name.shortened(self.encoding)
+                    } else {
+                        self.variable_name
+                            .with_suffix(
+                                &format!("_{}", Display26Adic::new_uppercase(self.index)),
+                                self.encoding,
+                                8,
+                            )
+                            .or_else(|_| {
+                                Identifier::new(format!("V{}", Display26Adic::new_uppercase(self.index)))
+                            })
+                            .unwrap()
+                    };
+                    if !self.used_names.contains(&name) {
+                        self.used_names.insert(name.clone());
+                        return name;
+                    }
+                    self.index += 1;
                 }
             }
-            unreachable!()
         }
 
         let mut used_names = HashSet::new();
@@ -758,21 +775,15 @@ impl Dictionary {
         // then similarly for additional segments.
         for (variable, short_names) in self.variables.iter().zip(short_names.iter_mut()) {
             if short_names[0].is_none() {
-                short_names[0] = Some(pick_short_name(
-                    &variable.name,
-                    &mut used_names,
-                    self.encoding,
-                ));
+                short_names[0] =
+                    Some(PickShortName::new(&variable.name, &mut used_names, self.encoding).next());
             }
         }
         for (variable, short_names) in self.variables.iter().zip(short_names.iter_mut()) {
+            let mut picker = PickShortName::new(&variable.name, &mut used_names, self.encoding);
             for assigned_short_name in short_names.iter_mut().skip(1) {
                 if assigned_short_name.is_none() {
-                    *assigned_short_name = Some(pick_short_name(
-                        &variable.name,
-                        &mut used_names,
-                        self.encoding,
-                    ));
+                    *assigned_short_name = Some(picker.next());
                 }
             }
         }
index 67217b4925c8f7fadbc2fe7e0bb9769cc8811d99..f660f858929ff5a8fe917e8f1062092d1525ba15 100644 (file)
@@ -26,7 +26,7 @@
 //! Use [WriteOptions] to write a system file.
 
 // Warn about missing docs, but not for items declared with `#[cfg(test)]`.
-#![cfg_attr(not(test), warn(missing_docs))]
+//#![cfg_attr(not(test), warn(missing_docs))]
 
 mod cooked;
 use binrw::Endian;
index 8753eaa2a90472c7364b79081401f318abb8bb1a..d21a8383ae762c47a69d11d9922bbb1c88da7a2d 100644 (file)
@@ -17,7 +17,7 @@ use crate::{
         Alignment, Attributes, CategoryLabels, Measure, MissingValueRange, MissingValues,
         MissingValuesError, VarType, VarWidth,
     },
-    endian::{FromBytes},
+    endian::FromBytes,
     format::{DisplayPlainF64, Format, Type},
     identifier::{Error as IdError, Identifier},
     sys::{
index 41118e47f6acad64f95ebea2629ca748d72d07ab..ab29ee837cc9295646caf0b679d6be59e207a907 100644 (file)
@@ -25,7 +25,7 @@ use std::{
     path::{Path, PathBuf},
 };
 
-use crate::endian::{ToBytes};
+use crate::endian::ToBytes;
 
 pub type Result<T, F = Error> = std::result::Result<T, F>;
 
index e59526b8b7e440ea357aeabcfb29a7d316d5bed0..b293fbdad88491c168a3d945a85d9febd5490b20 100644 (file)
@@ -15,7 +15,6 @@
 // this program.  If not, see <http://www.gnu.org/licenses/>.
 
 use std::{
-    borrow::Cow,
     fs::File,
     io::{BufRead, BufReader, Cursor, Seek},
     path::{Path, PathBuf},
@@ -23,9 +22,7 @@ use std::{
 };
 
 use binrw::Endian;
-use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
 use encoding_rs::UTF_8;
-use hexplay::HexView;
 
 use crate::{
     crypto::EncryptedFile,
@@ -40,7 +37,7 @@ use crate::{
         cooked::ReadOptions,
         raw::{self, records::Compression, ErrorDetails},
         sack::sack,
-        ProductVersion, WriteOptions,
+        WriteOptions,
     },
 };
 
@@ -573,23 +570,6 @@ fn encrypted_file_without_password() {
     ));
 }
 
-impl WriteOptions {
-    /// Returns a [WriteOptions] with the given `compression` and the other
-    /// members set to fixed values so that running at different times or with
-    /// different crate names or versions won't change what's written to the
-    /// file.
-    fn reproducible(compression: Option<Compression>) -> Self {
-        WriteOptions::new()
-            .with_compression(compression)
-            .with_timestamp(NaiveDateTime::new(
-                NaiveDate::from_ymd_opt(2025, 7, 30).unwrap(),
-                NaiveTime::from_hms_opt(15, 7, 55).unwrap(),
-            ))
-            .with_product_name(Cow::from("PSPP TEST DATA FILE"))
-            .with_product_version(ProductVersion(1, 2, 3))
-    }
-}
-
 /// Tests the most basic kind of writing a system file, just writing a few
 /// numeric variables and cases.
 #[test]
index b5950951109afff514bfdb0976b0c40e32d9bef6..193f17769c4f699899fcd5dd0decbeef1c3f1d2e 100644 (file)
@@ -151,6 +151,23 @@ impl WriteOptions {
         let DictionaryWriter { case_vars, .. } = dict_writer;
         Writer::new(self, case_vars, writer)
     }
+
+    /// Returns a [WriteOptions] with the given `compression` and the other
+    /// members set to fixed values so that running at different times or with
+    /// different crate names or versions won't change what's written to the
+    /// file.
+    #[cfg(test)]
+    pub(super) fn reproducible(compression: Option<Compression>) -> Self {
+        use chrono::{NaiveDate, NaiveTime};
+        WriteOptions::new()
+            .with_compression(compression)
+            .with_timestamp(NaiveDateTime::new(
+                NaiveDate::from_ymd_opt(2025, 7, 30).unwrap(),
+                NaiveTime::from_hms_opt(15, 7, 55).unwrap(),
+            ))
+            .with_product_name(Cow::from("PSPP TEST DATA FILE"))
+            .with_product_version(ProductVersion(1, 2, 3))
+    }
 }
 
 struct DictionaryWriter<'a, W> {
@@ -329,6 +346,7 @@ where
                     },
                 )
                     .write_le(self.writer)?;
+                write_variable_continuation_records(&mut self.writer, width)?;
             }
         }
 
@@ -1200,3 +1218,141 @@ where
         Err(IoError::from(ErrorKind::NotSeekable))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::io::Cursor;
+
+    use binrw::BinRead;
+    use encoding_rs::UTF_8;
+    use itertools::Itertools;
+
+    use crate::{
+        dictionary::{Dictionary, VarWidth, Variable},
+        identifier::Identifier,
+        sys::{
+            raw::records::{RawHeader, RawVariableRecord},
+            write::DictionaryWriter,
+            WriteOptions,
+        },
+    };
+
+    #[test]
+    fn header() {
+        for variables in [
+            (VarWidth::Numeric, 1),
+            (VarWidth::String(1), 1),
+            (VarWidth::String(8), 1),
+            (VarWidth::String(15), 2),
+            (VarWidth::String(255), 32),
+            (VarWidth::String(256), 33),
+            (VarWidth::String(20000), 79 * 32 + 12),
+        ]
+        .iter()
+        .copied()
+        .combinations_with_replacement(4)
+        {
+            let mut dictionary = Dictionary::new(UTF_8);
+            let mut expected_case_size = 0;
+            let mut weight_indexes = vec![(None, 0)];
+            for (index, (width, n_chunks)) in variables.into_iter().enumerate() {
+                let index = dictionary
+                    .add_var(Variable::new(
+                        Identifier::new(format!("v{index}")).unwrap(),
+                        width,
+                        UTF_8,
+                    ))
+                    .unwrap();
+                if width.is_numeric() {
+                    weight_indexes.push((Some(index), expected_case_size + 1));
+                }
+                expected_case_size += n_chunks;
+            }
+            for (weight_index, expected_weight_index) in weight_indexes {
+                dictionary.set_weight(weight_index).unwrap();
+
+                let mut raw = Vec::new();
+                DictionaryWriter::new(
+                    &WriteOptions::reproducible(None),
+                    &mut Cursor::new(&mut raw),
+                    &dictionary,
+                )
+                .write_header()
+                .unwrap();
+                let header = RawHeader::read_le(&mut Cursor::new(&raw)).unwrap();
+                assert_eq!(header.weight_index, expected_weight_index as u32);
+                assert_eq!(header.nominal_case_size, expected_case_size as u32);
+            }
+        }
+    }
+
+    #[test]
+    fn variables() {
+        let variables = [
+            (VarWidth::Numeric, vec![0]),
+            (VarWidth::String(1), vec![1]),
+            (VarWidth::String(8), vec![8]),
+            (VarWidth::String(15), vec![15, -1]),
+            (
+                VarWidth::String(255),
+                std::iter::once(255)
+                    .chain(std::iter::repeat_n(-1, 31))
+                    .collect(),
+            ),
+            (
+                VarWidth::String(256),
+                std::iter::once(255)
+                    .chain(std::iter::repeat_n(-1, 31))
+                    .chain(std::iter::once(4))
+                    .collect(),
+            ),
+            (
+                VarWidth::String(20000),
+                std::iter::once(255)
+                    .chain(std::iter::repeat_n(-1, 31))
+                    .cycle()
+                    .take(32 * 79)
+                    .chain(std::iter::once(92))
+                    .chain(std::iter::repeat_n(-1, 11))
+                    .collect(),
+            ),
+        ];
+        for variables in variables.iter().combinations_with_replacement(4) {
+            let mut dictionary = Dictionary::new(UTF_8);
+            for (index, (width, _)) in variables.iter().enumerate() {
+                dictionary
+                    .add_var(Variable::new(
+                        Identifier::new(format!("v{index}")).unwrap(),
+                        *width,
+                        UTF_8,
+                    ))
+                    .unwrap();
+            }
+
+            let widths = variables
+                .into_iter()
+                .map(|(_, w)| w.iter())
+                .flatten()
+                .copied();
+
+            let mut raw = Vec::new();
+            DictionaryWriter::new(
+                &WriteOptions::reproducible(None),
+                &mut Cursor::new(&mut raw),
+                &dictionary,
+            )
+            .write_variables()
+            .unwrap();
+
+            let mut cursor = Cursor::new(&raw);
+            let mut records = Vec::new();
+            while cursor.position() < raw.len() as u64 {
+                assert_eq!(u32::read_le(&mut cursor).unwrap(), 2);
+                records.push(RawVariableRecord::read_le(&mut cursor).unwrap());
+            }
+            for (record, expected_width) in records.iter().zip_eq(widths.into_iter()) {
+                assert_eq!(record.width, expected_width);
+            }
+        }
+    }
+}