work on writing sav files
authorBen Pfaff <blp@cs.stanford.edu>
Sat, 19 Jul 2025 00:34:03 +0000 (17:34 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Sat, 19 Jul 2025 00:34:03 +0000 (17:34 -0700)
rust/pspp/src/dictionary.rs
rust/pspp/src/output/spv.rs
rust/pspp/src/sys/raw/records.rs
rust/pspp/src/sys/write.rs

index 7e1bb237053f7b4f168637ce592c1dc32f7f1c58..4a4e9f9d9e6458486c2006c31069893248017efe 100644 (file)
@@ -22,7 +22,7 @@ use std::{
     cmp::Ordering,
     collections::{BTreeMap, BTreeSet, HashMap, HashSet},
     fmt::{Debug, Display, Formatter, Result as FmtResult},
-    hash::Hash,
+    hash::{DefaultHasher, Hash, Hasher},
     ops::{Bound, Not, RangeBounds, RangeInclusive},
     str::FromStr,
 };
@@ -182,7 +182,7 @@ impl VarWidth {
 
     /// Number of bytes per segment by which the amount of space for very long
     /// string variables is allocated.
-    const EFFECTIVE_VLS_CHUNK: usize = 252;
+    pub const SEGMENT_SIZE: usize = 252;
 
     /// Returns the number of "segments" used for writing case data for a
     /// variable with this width.  A segment is a physical variable in the
@@ -191,14 +191,24 @@ impl VarWidth {
     /// segment.
     pub fn n_segments(&self) -> usize {
         if self.is_very_long() {
-            self.as_string_width()
-                .unwrap()
-                .div_ceil(Self::EFFECTIVE_VLS_CHUNK)
+            self.as_string_width().unwrap().div_ceil(Self::SEGMENT_SIZE)
         } else {
             1
         }
     }
 
+    /// Returns the number of 8-byte chunks used for writing case data for a
+    /// variable with this width.  This concept does not apply to very long
+    /// string variables, which are divided into [multiple
+    /// segments](Self::n_segments) (which in turn are divided into chunks).
+    pub fn n_chunks(&self) -> Option<usize> {
+        match *self {
+            VarWidth::Numeric => Some(1),
+            VarWidth::String(w) if w <= 255 => Some(w.div_ceil(8) as usize),
+            VarWidth::String(_) => None,
+        }
+    }
+
     /// Returns the width to allocate to the segment with the given
     /// `segment_idx` within this variable.  A segment is a physical variable in
     /// the system file that represents some piece of a logical variable as seen
@@ -210,7 +220,7 @@ impl VarWidth {
         if segment_idx < self.n_segments() - 1 {
             255
         } else {
-            self.as_string_width().unwrap() - segment_idx * Self::EFFECTIVE_VLS_CHUNK
+            self.as_string_width().unwrap() - segment_idx * Self::SEGMENT_SIZE
         }
     }
 
@@ -1360,7 +1370,7 @@ impl VariableSet {
     }
 }
 
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
 pub struct ValueLabels(pub HashMap<Datum, String>);
 
 impl ValueLabels {
@@ -1396,6 +1406,19 @@ impl ValueLabels {
     }
 }
 
+impl Hash for ValueLabels {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let mut hash = 0;
+        for (k, v) in &self.0 {
+            let mut hasher = DefaultHasher::new();
+            k.hash(&mut hasher);
+            v.hash(&mut hasher);
+            hash ^= hasher.finish();
+        }
+        state.write_u64(hash);
+    }
+}
+
 #[derive(Clone, Default)]
 pub struct MissingValues {
     /// Individual missing values, up to 3 of them.
@@ -1423,6 +1446,14 @@ pub enum MissingValuesError {
 }
 
 impl MissingValues {
+    pub fn values(&self) -> &[Datum] {
+        &self.values
+    }
+
+    pub fn range(&self) -> Option<&MissingValueRange> {
+        self.range.as_ref()
+    }
+
     pub fn new(
         mut values: Vec<Datum>,
         range: Option<MissingValueRange>,
index f34090fbcacce2e485e48106b4ee7ac39b69803d..43ed553d6dc09ee258489faf51cab5f9e4125efd 100644 (file)
@@ -899,7 +899,7 @@ where
     }
 }
 
-struct Zeros(usize);
+pub struct Zeros(pub usize);
 
 impl BinWrite for Zeros {
     type Args<'a> = ();
index 7dbe107959d7b9b5e95a404da33cf8dde0505163..606bb4ef21cf5bba44967b1ca9289dee0634c3e2 100644 (file)
@@ -18,6 +18,7 @@ use crate::{
         VarWidth,
     },
     endian::{Endian, Parse},
+    format::{Format, Type},
     identifier::{Error as IdError, Identifier},
     sys::raw::{
         read_bytes, read_string, read_vec, Decoder, Error, ErrorDetails, Magic, RawDatum,
@@ -127,6 +128,7 @@ where
     }
 }
 
+#[allow(missing_docs)]
 #[derive(BinRead, BinWrite)]
 pub struct RawHeader {
     pub magic: [u8; 4],
@@ -253,13 +255,33 @@ impl FileHeader<RawString> {
 }
 
 /// [Format](crate::format::Format) as represented in a system file.
-#[derive(Copy, Clone, PartialEq, Eq, Hash)]
+#[derive(Copy, Clone, PartialEq, Eq, Hash, BinRead, BinWrite)]
 pub struct RawFormat(
     /// The most-significant 16 bits are the type, the next 8 bytes are the
     /// width, and the least-significant 8 bits are the number of decimals.
     pub u32,
 );
 
+/// Cannot convert very long string (wider than 255 bytes) to [RawFormat].
+#[derive(Copy, Clone, Debug)]
+pub struct VeryLongStringError;
+
+impl TryFrom<Format> for RawFormat {
+    type Error = VeryLongStringError;
+
+    fn try_from(value: Format) -> Result<Self, Self::Error> {
+        let type_ = u16::from(value.type_()) as u32;
+        let w = match value.var_width() {
+            VarWidth::Numeric => value.w() as u8,
+            VarWidth::String(w) if w > 255 => return Err(VeryLongStringError),
+            VarWidth::String(w) if value.type_() == Type::AHex => (w * 2).min(255) as u8,
+            VarWidth::String(w) => w as u8,
+        } as u32;
+        let d = value.d() as u32;
+        Ok(Self((type_ << 16) | (w << 8) | d))
+    }
+}
+
 impl Debug for RawFormat {
     fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
         let type_ = format_name(self.0 >> 16);
@@ -469,6 +491,17 @@ where
     }
 }
 
+#[allow(missing_docs)]
+#[derive(BinRead, BinWrite)]
+pub struct RawVariableRecord {
+    pub width: i32,
+    pub has_variable_label: u32,
+    pub missing_value_code: i32,
+    pub print_format: RawFormat,
+    pub write_format: RawFormat,
+    pub name: [u8; 8],
+}
+
 impl VariableRecord<RawString> {
     /// Reads a variable record from `r`.
     pub fn read<R>(
@@ -479,16 +512,6 @@ impl VariableRecord<RawString> {
     where
         R: Read + Seek,
     {
-        #[derive(BinRead)]
-        struct RawVariableRecord {
-            width: i32,
-            has_variable_label: u32,
-            missing_value_code: i32,
-            print_format: u32,
-            write_format: u32,
-            name: [u8; 8],
-        }
-
         let start_offset = r.stream_position()?;
         let offsets = start_offset..start_offset + 28;
         let raw_record =
@@ -538,8 +561,8 @@ impl VariableRecord<RawString> {
             offsets: start_offset..end_offset,
             width,
             name: RawString(raw_record.name.into()),
-            print_format: RawFormat(raw_record.print_format),
-            write_format: RawFormat(raw_record.write_format),
+            print_format: raw_record.print_format,
+            write_format: raw_record.write_format,
             missing_values,
             label,
         }))
index e65edfd29e59737884207509ac7268dda0879399..83d5663737f1fc0b9c900e055da5012f4bdb251f 100644 (file)
@@ -1,16 +1,23 @@
+#![allow(dead_code, missing_docs)]
+use core::f64;
 use std::{
+    collections::HashMap,
     io::{Seek, Write},
-    iter::repeat_n,
 };
 
-use binrw::{BinWrite, Error as BinError};
+use binrw::{BinWrite, Endian, Error as BinError};
 use chrono::Local;
+use encoding_rs::Encoding;
 use smallvec::SmallVec;
 
 use crate::{
-    dictionary::{Dictionary, VarWidth},
+    data::Datum,
+    dictionary::{Dictionary, ValueLabels, VarWidth},
+    format::Format,
+    identifier::Identifier,
+    output::spv::Zeros,
     sys::raw::{
-        records::{Compression, RawHeader},
+        records::{Compression, RawFormat, RawHeader, RawVariableRecord},
         Magic,
     },
 };
@@ -44,18 +51,41 @@ impl Default for WriteOptions {
     }
 }
 
-impl WriteOptions {
-    pub fn new() -> Self {
-        Self::default()
+struct DictionaryWriter<'a, W> {
+    compression: Option<Compression>,
+    version: Version,
+    short_names: Vec<SmallVec<[Identifier; 1]>>,
+    case_vars: Vec<CaseVar>,
+    writer: &'a mut W,
+    dictionary: &'a Dictionary,
+}
+
+impl<'a, W> DictionaryWriter<'a, W>
+where
+    W: Write + Seek,
+{
+    pub fn new(options: &WriteOptions, writer: &'a mut W, dictionary: &'a Dictionary) -> Self {
+        Self {
+            compression: options.compression,
+            version: options.version,
+            short_names: dictionary.short_names(),
+            case_vars: dictionary
+                .variables
+                .iter()
+                .map(|variable| CaseVar::new(variable.width))
+                .collect::<Vec<_>>(),
+            writer,
+            dictionary,
+        }
     }
-    pub fn write_writer<W>(
-        self,
-        dictionary: &Dictionary,
-        mut writer: W,
-    ) -> Result<Writer<W>, BinError>
-    where
-        W: Write + Seek,
-    {
+
+    pub fn write(&mut self) -> Result<(), BinError> {
+        self.write_header()?;
+        self.write_variables()?;
+        self.write_value_labels()
+    }
+
+    fn write_header(&mut self) -> Result<(), BinError> {
         fn as_byte_array<const N: usize>(s: String) -> [u8; N] {
             let mut bytes = s.into_bytes();
             bytes.resize(N, b' ');
@@ -66,12 +96,6 @@ impl WriteOptions {
             case_vars.iter().map(CaseVar::n_segments).sum::<usize>() as u32
         }
 
-        let case_vars = dictionary
-            .variables
-            .iter()
-            .map(|variable| CaseVar::new(variable.width))
-            .collect::<Vec<_>>();
-
         let now = Local::now();
         let header = RawHeader {
             magic: if self.compression == Some(Compression::ZLib) {
@@ -87,14 +111,14 @@ impl WriteOptions {
                 ))
             },
             layout_code: 2,
-            nominal_case_size: count_segments(&case_vars),
+            nominal_case_size: count_segments(&self.case_vars),
             compression_code: match self.compression {
                 Some(Compression::Simple) => 1,
                 Some(Compression::ZLib) => 2,
                 None => 0,
             },
-            weight_index: if let Some(weight_index) = dictionary.weight {
-                count_segments(&case_vars[..weight_index]) + 1
+            weight_index: if let Some(weight_index) = self.dictionary.weight {
+                count_segments(&self.case_vars[..weight_index]) + 1
             } else {
                 0
             },
@@ -102,11 +126,243 @@ impl WriteOptions {
             bias: 100.0,
             creation_date: as_byte_array(now.format("%d %b %Y").to_string()),
             creation_time: as_byte_array(now.format("%H:%M:%S").to_string()),
-            file_label: as_byte_array(dictionary.file_label.clone().unwrap_or_default()),
+            file_label: as_byte_array(self.dictionary.file_label.clone().unwrap_or_default()),
         };
-        header.write_le(&mut writer)?;
+        header.write_le(self.writer)
+    }
+
+    fn write_variables(&mut self) -> Result<(), BinError> {
+        for (variable, short_names) in self
+            .dictionary
+            .variables
+            .iter()
+            .zip(self.short_names.iter())
+        {
+            let mut segment_widths = SegmentWidths::new(variable.width);
+            let mut short_names = short_names.iter();
+            let seg0_width = segment_widths.next().unwrap();
+            let name0 = short_names.next().unwrap();
+            let record = RawVariableRecord {
+                width: seg0_width.as_string_width().unwrap_or(0) as i32,
+                has_variable_label: variable.label.is_some() as u32,
+                missing_value_code: if variable.width.is_long_string() {
+                    let n = variable.missing_values.values().len() as i32;
+                    match variable.missing_values.range() {
+                        Some(_) => -(n + 2),
+                        None => n,
+                    }
+                } else {
+                    0
+                },
+                print_format: to_raw_format(variable.print_format, seg0_width),
+                write_format: to_raw_format(variable.write_format, seg0_width),
+                name: encode_fixed_string(name0, variable.encoding),
+            };
+            (2u32, record).write_le(self.writer)?;
+
+            // Variable label.
+            if let Some(label) = variable.label() {
+                let label = variable.encoding.encode(&label).0;
+                let len = label.len().min(255) as u32;
+                let padded_len = len.next_multiple_of(4);
+                (len, &*label, Zeros((padded_len - len) as usize)).write_le(self.writer)?;
+            }
+
+            // Missing values.
+            if !variable.width.is_long_string() {
+                if let Some(range) = variable.missing_values.range() {
+                    (
+                        range.low().unwrap_or(-f64::MAX),
+                        range.high().unwrap_or(f64::MAX),
+                    )
+                        .write_le(self.writer)?;
+                }
+                variable.missing_values.values().write_le(self.writer)?;
+            }
+            write_variable_continuation_records(&mut self.writer, seg0_width)?;
+
+            // Write additional segments for very long string variables.
+            for (width, name) in segment_widths.zip(short_names) {
+                let format: RawFormat = Format::default_for_width(width).try_into().unwrap();
+                (
+                    2u32,
+                    RawVariableRecord {
+                        width: width.as_string_width().unwrap() as i32,
+                        has_variable_label: 0,
+                        missing_value_code: 0,
+                        print_format: format,
+                        write_format: format,
+                        name: encode_fixed_string(name, variable.encoding),
+                    },
+                )
+                    .write_le(self.writer)?;
+            }
+        }
+
+        fn write_variable_continuation_records<W>(
+            mut writer: W,
+            width: VarWidth,
+        ) -> Result<(), BinError>
+        where
+            W: Write + Seek,
+        {
+            let continuation = (
+                2u32,
+                RawVariableRecord {
+                    width: -1,
+                    has_variable_label: 0,
+                    missing_value_code: 0,
+                    print_format: RawFormat(0),
+                    write_format: RawFormat(0),
+                    name: [0; 8],
+                },
+            );
+            for _ in 1..width.n_chunks().unwrap() {
+                continuation.write_le(&mut writer)?;
+            }
+            Ok(())
+        }
+
+        fn encode_fixed_string<const N: usize>(s: &str, encoding: &'static Encoding) -> [u8; N] {
+            let mut encoded = encoding.encode(s).0.into_owned();
+            encoded.resize(N, b' ');
+            encoded.try_into().unwrap()
+        }
+        fn to_raw_format(mut format: Format, width: VarWidth) -> RawFormat {
+            format.resize(width);
+            RawFormat::try_from(format).unwrap()
+        }
+
+        Ok(())
+    }
+
+    /// Writes value label records, except for long string variables.
+    fn write_value_labels(&mut self) -> Result<(), BinError> {
+        // Collect identical sets of value labels.
+        let mut sets = HashMap::<&ValueLabels, Vec<_>>::new();
+        let mut index = 1usize;
+        for variable in &self.dictionary.variables {
+            if !variable.width.is_long_string() && !variable.value_labels.is_empty() {
+                sets.entry(&variable.value_labels)
+                    .or_default()
+                    .push(index as u32);
+            }
+            index += SegmentWidths::new(variable.width)
+                .map(|w| w.n_chunks().unwrap())
+                .sum::<usize>();
+        }
+
+        for (value_labels, variables) in sets {
+            // Label record.
+            (3u32, value_labels.0.len() as u32).write_le(self.writer)?;
+            for (datum, label) in &value_labels.0 {
+                let label = &*self.dictionary.encoding.encode(&label).0;
+                let padding = label.len().next_multiple_of(8) - label.len();
+                (datum, label.len() as u32, label, Zeros(padding)).write_le(self.writer)?;
+            }
+
+            // Variable record.
+            (4u32, variables.len() as u32, variables).write_le(self.writer)?;
+        }
         todo!()
     }
+
+    pub fn write_documents(&mut self) -> Result<(), BinError> {
+        if !self.dictionary.documents.is_empty() {
+            (6u32, self.dictionary.documents.len() as u32).write_le(self.writer)?;
+            for line in &self.dictionary.documents {
+                Padded::exact(&*self.dictionary.encoding.encode(&line).0, 80, b' ')
+                    .write_le(self.writer)?;
+            }
+        }
+        Ok(())
+    }
+}
+
+#[derive(BinWrite)]
+struct Padded<'a> {
+    bytes: &'a [u8],
+    padding: Pad,
+}
+
+impl<'a> Padded<'a> {
+    pub fn exact(bytes: &'a [u8], length: usize, pad: u8) -> Self {
+        let min = bytes.len().min(length);
+        Self {
+            bytes: &bytes[..min],
+            padding: Pad::new(length - min, pad),
+        }
+    }
+
+    pub fn to_multiple(bytes: &'a [u8], multiple: usize, pad: u8) -> Self {
+        let length = bytes.len().next_multiple_of(multiple);
+        Self {
+            padding: Pad::new(length - bytes.len(), pad),
+            bytes,
+        }
+    }
+}
+
+pub struct Pad {
+    n: usize,
+    pad: u8,
+}
+
+impl Pad {
+    pub fn new(n: usize, pad: u8) -> Self {
+        Self { n, pad }
+    }
+}
+
+impl BinWrite for Pad {
+    type Args<'a> = ();
+
+    fn write_options<W: Write + Seek>(
+        &self,
+        writer: &mut W,
+        _endian: Endian,
+        _args: Self::Args<'_>,
+    ) -> binrw::BinResult<()> {
+        for _ in 0..self.n {
+            writer.write_all(&[self.pad])?;
+        }
+        Ok(())
+    }
+}
+
+impl WriteOptions {
+    pub fn new() -> Self {
+        Self::default()
+    }
+    pub fn write_writer<W>(
+        self,
+        dictionary: &Dictionary,
+        mut writer: W,
+    ) -> Result<Writer<W>, BinError>
+    where
+        W: Write + Seek,
+    {
+        DictionaryWriter::new(&self, &mut writer, dictionary).write()?;
+        todo!()
+    }
+}
+
+impl BinWrite for Datum {
+    type Args<'a> = ();
+
+    fn write_options<W: Write + Seek>(
+        &self,
+        writer: &mut W,
+        endian: binrw::Endian,
+        _: (),
+    ) -> binrw::BinResult<()> {
+        match self {
+            Datum::Number(number) => number
+                .unwrap_or(-f64::MAX)
+                .write_options(writer, endian, ()),
+            Datum::String(raw_string) => raw_string.0.write_options(writer, endian, ()),
+        }
+    }
 }
 
 #[derive(Debug)]
@@ -115,15 +371,39 @@ struct StringSegment {
     padding_bytes: usize,
 }
 
-fn segment_widths(width: usize) -> impl Iterator<Item = usize> {
-    let n_segments = width.div_ceil(252);
-    repeat_n(255, n_segments - 1)
-        .chain(if n_segments > 1 {
-            std::iter::once(width - (n_segments - 1) * 252)
+struct SegmentWidths {
+    width: VarWidth,
+    i: usize,
+    n: usize,
+}
+impl SegmentWidths {
+    pub fn new(width: VarWidth) -> Self {
+        Self {
+            width,
+            i: 0,
+            n: width.n_segments(),
+        }
+    }
+}
+
+impl Iterator for SegmentWidths {
+    type Item = VarWidth;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let i = self.i;
+        if i >= self.n {
+            None
         } else {
-            std::iter::once(width)
-        })
-        .map(|w| w.next_multiple_of(8))
+            self.i += 1;
+            match self.width {
+                VarWidth::Numeric => Some(VarWidth::Numeric),
+                VarWidth::String(_) if i < self.n - 1 => Some(VarWidth::String(255)),
+                VarWidth::String(width) => Some(VarWidth::String(
+                    width - (self.n as u16 - 1) * VarWidth::SEGMENT_SIZE as u16,
+                )),
+            }
+        }
+    }
 }
 
 enum CaseVar {
@@ -138,11 +418,11 @@ impl CaseVar {
     fn new(width: VarWidth) -> Self {
         match width {
             VarWidth::Numeric => Self::Numeric,
-            VarWidth::String(width) => {
-                let width = width as usize;
+            VarWidth::String(w) => {
                 let mut encoding = SmallVec::<[StringSegment; 1]>::new();
-                let mut remaining = width;
-                for segment in segment_widths(width) {
+                let mut remaining = w as usize;
+                for segment in SegmentWidths::new(width) {
+                    let segment = segment.as_string_width().unwrap().next_multiple_of(8);
                     let data_bytes = remaining.min(segment).min(255);
                     let padding_bytes = segment - data_bytes;
                     if data_bytes > 0 {
@@ -155,7 +435,10 @@ impl CaseVar {
                         encoding.last_mut().unwrap().padding_bytes += padding_bytes;
                     }
                 }
-                CaseVar::String { width, encoding }
+                CaseVar::String {
+                    width: w as usize,
+                    encoding,
+                }
             }
         }
     }