Change Axis from holding Arcs to dimensions to indexes.
authorBen Pfaff <blp@cs.stanford.edu>
Sun, 6 Apr 2025 15:49:53 +0000 (08:49 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Sun, 6 Apr 2025 15:49:53 +0000 (08:49 -0700)
rust/pspp/src/output/pivot/mod.rs
rust/pspp/src/output/pivot/output.rs

index ec6f68c797c99378c95ee297bca093c28e251fcc..6c17503f66435409a679aa075e8132d6306bcbdb 100644 (file)
@@ -59,7 +59,7 @@ use std::{
     collections::HashMap,
     fmt::{Debug, Display, Write},
     io::Read,
-    iter::{once, repeat, repeat_n},
+    iter::{once, repeat, repeat_n, FusedIterator},
     ops::{Index, IndexMut, Not, Range, RangeInclusive},
     str::{from_utf8, FromStr, Utf8Error},
     sync::{Arc, OnceLock, Weak},
@@ -76,7 +76,7 @@ use look_xml::TableProperties;
 use quick_xml::{de::from_str, DeError};
 use serde::{de::Visitor, Deserialize};
 use smallstr::SmallString;
-use smallvec::{smallvec, SmallVec};
+use smallvec::SmallVec;
 use thiserror::Error as ThisError;
 use tlo::parse_tlo;
 
@@ -250,65 +250,65 @@ impl Axis3 {
     }
 }
 
+impl From<Axis2> for Axis3 {
+    fn from(axis2: Axis2) -> Self {
+        match axis2 {
+            Axis2::X => Self::X,
+            Axis2::Y => Self::Y,
+        }
+    }
+}
+
 /// An axis within a pivot table.
 #[derive(Clone, Debug, Default)]
 pub struct Axis {
     /// `dimensions[0]` is the innermost dimension.
-    dimensions: Vec<Arc<Dimension>>,
-
-    /// The number of rows or columns along the axis, that is, the product of
-    /// `dimensions[*].len()`.  It is 0 if any dimension has 0 leaves.
-    extent: usize,
+    dimensions: Vec<usize>,
 
     /// Sum of `dimensions[*].label_depth`.
     label_depth: usize,
 }
 
-pub struct AxisIterator<'a> {
-    axis: &'a Axis,
+pub struct AxisIterator {
     indexes: SmallVec<[usize; 4]>,
+    lengths: SmallVec<[usize; 4]>,
     done: bool,
 }
 
-impl<'a> Iterator for AxisIterator<'a> {
+impl FusedIterator for AxisIterator {}
+impl Iterator for AxisIterator {
     type Item = SmallVec<[usize; 4]>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        if self.indexes.is_empty() {
-            if self.done
-                || self
-                    .axis
-                    .dimensions
-                    .iter()
-                    .any(|dimension| dimension.is_empty())
-            {
-                None
-            } else {
-                self.done = true;
-                self.indexes = smallvec![0; self.axis.dimensions.len()];
-                Some(self.indexes.clone())
-            }
+        if self.done {
+            None
         } else {
-            for (index, dimension) in self.indexes.iter_mut().zip(self.axis.dimensions.iter()) {
+            let retval = self.indexes.clone();
+            for (index, len) in self.indexes.iter_mut().zip(self.lengths.iter().copied()) {
                 *index += 1;
-                if *index < dimension.len() {
-                    return Some(self.indexes.clone());
+                if *index < len {
+                    return Some(retval);
                 };
-                *index = 0
+                *index = 0;
             }
-            None
+            self.done = true;
+            Some(retval)
         }
     }
 }
 
-impl Axis {
-    fn iter(&self) -> AxisIterator {
+impl PivotTable {
+    fn axis_values(&self, axis: Axis3) -> AxisIterator {
         AxisIterator {
-            axis: self,
-            indexes: SmallVec::new(),
-            done: false,
+            indexes: repeat_n(0, self.axes[axis].dimensions.len()).collect(),
+            lengths: self.axis_dimensions(axis).map(|d| d.len()).collect(),
+            done: self.axis_extent(axis) == 0,
         }
     }
+
+    fn axis_extent(&self, axis: Axis3) -> usize {
+        self.axis_dimensions(axis).map(|d| d.len()).product()
+    }
 }
 
 /// Dimensions.
@@ -662,7 +662,10 @@ impl PivotTableBuilder {
         let corner_text = false;
         let mut table = PivotTable::new(self.title, self.look.clone());
         let mut dimensions = Vec::with_capacity(self.dimensions.len());
-        let mut axes = EnumMap::from_fn(|_key| Vec::with_capacity(self.dimensions.len()));
+        let mut axes = EnumMap::from_fn(|_axis| Axis {
+            dimensions: Vec::with_capacity(self.dimensions.len()),
+            label_depth: 0,
+        });
         for (top_index, d) in self.dimensions.into_iter().enumerate() {
             let axis = d.axis;
             let label_position = if axis == Axis3::Y && !corner_text {
@@ -670,20 +673,13 @@ impl PivotTableBuilder {
             } else {
                 LabelPosition::Nested
             };
-            let d = Arc::new(d.build(axes[axis].len(), top_index, label_position));
-            axes[axis].push(d.clone());
+            let d = d.build(axes[axis].dimensions.len(), top_index, label_position);
+            axes[axis].dimensions.push(dimensions.len());
+            axes[axis].label_depth += d.label_depth();
             dimensions.push(d);
         }
         table.dimensions = dimensions;
-        table.axes = axes.map(|_axis, dimensions| {
-            let label_depth = dimensions.iter().map(|d| d.label_depth()).sum();
-            let extent = dimensions.iter().map(|d| d.data_leaves.len()).product();
-            Axis {
-                dimensions,
-                extent,
-                label_depth,
-            }
-        });
+        table.axes = axes;
         table.cells = self.cells;
         table.current_layer = repeat_n(0, table.axes[Axis3::Z].dimensions.len()).collect();
         table
@@ -1522,7 +1518,7 @@ pub struct PivotTable {
     pub corner_text: Option<Box<Value>>,
     pub caption: Option<Box<Value>>,
     pub notes: Option<String>,
-    pub dimensions: Vec<Arc<Dimension>>,
+    pub dimensions: Vec<Dimension>,
     pub axes: EnumMap<Axis3, Axis>,
     pub cells: HashMap<usize, Value>,
 }
@@ -1603,10 +1599,9 @@ impl PivotTable {
         presentation_indexes: EnumMap<Axis3, &[usize]>,
     ) -> SmallVec<[usize; 4]> {
         let mut data_indexes = SmallVec::from_elem(0, self.dimensions.len());
-        for i in enum_iterator::all::<Axis3>() {
-            let axis = &self.axes[i];
-            for (j, dimension) in axis.dimensions.iter().enumerate() {
-                let pindex = presentation_indexes[i][j];
+        for axis in enum_iterator::all::<Axis3>() {
+            for (i, dimension) in self.axis_dimensions(axis).enumerate() {
+                let pindex = presentation_indexes[axis][i];
                 data_indexes[dimension.top_index] =
                     dimension.presentation_leaves[pindex].data_index;
             }
@@ -1622,7 +1617,7 @@ impl PivotTable {
     /// - Otherwise, the iterator will just visit `self.current_layer`.
     pub fn layers(&self, print: bool) -> Box<dyn Iterator<Item = SmallVec<[usize; 4]>> + '_> {
         if print && self.look.print_all_layers {
-            Box::new(self.axes[Axis3::Z].iter())
+            Box::new(self.axis_values(Axis3::Z))
         } else {
             Box::new(once(SmallVec::from_slice(&self.current_layer)))
         }
@@ -1640,6 +1635,61 @@ impl PivotTable {
     pub fn transpose(&mut self) {
         self.axes.swap(Axis3::X, Axis3::Y);
     }
+
+    fn axis_dimensions(
+        &self,
+        axis: Axis3,
+    ) -> impl Iterator<Item = &Dimension> + DoubleEndedIterator + ExactSizeIterator {
+        self.axes[axis]
+            .dimensions
+            .iter()
+            .map(|index| &self.dimensions[*index])
+    }
+
+    fn find_dimension(&self, dim_index: usize) -> Option<(Axis3, usize)> {
+        debug_assert!(dim_index < self.dimensions.len());
+        for axis in enum_iterator::all::<Axis3>() {
+            for (position, dimension) in self.axes[axis].dimensions.iter().copied().enumerate() {
+                if dimension == dim_index {
+                    return Some((axis, position));
+                }
+            }
+        }
+        None
+    }
+    pub fn move_dimension(&mut self, dim_index: usize, new_axis: Axis3, new_position: usize) {
+        let (old_axis, old_position) = self.find_dimension(dim_index).unwrap();
+        if old_axis == new_axis && old_position == new_position {
+            return;
+        }
+
+        // Update the current layer, if necessary.  If we're moving within the
+        // layer axis, preserve the current layer.
+        match (old_axis, new_axis) {
+            (Axis3::Z, Axis3::Z) => {
+                // Rearrange the layer axis.
+                if old_position < new_position {
+                    self.current_layer[old_position..=new_position].rotate_left(1);
+                } else {
+                    self.current_layer[new_position..=old_position].rotate_right(1);
+                }
+            }
+            (Axis3::Z, _) => {
+                // A layer is becoming a row or column.
+                self.current_layer.remove(old_position);
+            }
+            (_, Axis3::Z) => {
+                // A row or column is becoming a layer.
+                self.current_layer.insert(new_position, 0);
+            }
+            _ => (),
+        }
+
+        self.axes[old_axis].dimensions.remove(old_position);
+        self.axes[new_axis]
+            .dimensions
+            .insert(new_position, dim_index);
+    }
 }
 
 pub struct Layers {}
index 536ae3f9b530940f05d3209d3f390ab988328c3f..52990c26ba889ec9031b8ba36314fff5dd7fbc8c 100644 (file)
@@ -10,9 +10,8 @@ use crate::output::{
 };
 
 use super::{
-    Area, AsValueOptions, Axis, Axis2, Axis3, Border, BorderStyle, BoxBorder, Category,
-    CategoryTrait, Color, Coord2, Dimension, Footnote, PivotTable, Rect2, RowColBorder, Stroke,
-    Value,
+    Area, AsValueOptions, Axis2, Axis3, Border, BorderStyle, BoxBorder, Category, CategoryTrait,
+    Color, Coord2, Dimension, Footnote, PivotTable, Rect2, RowColBorder, Stroke, Value,
 };
 
 /// All of the combinations of dimensions along an axis.
@@ -72,7 +71,7 @@ impl PivotTable {
         fixed_axis: Axis3,
     ) -> bool {
         let vary_axis = fixed_axis.transpose().unwrap();
-        for vary_indexes in self.axes[vary_axis].iter() {
+        for vary_indexes in self.axis_values(vary_axis) {
             let mut presentation_indexes = enum_map! {
                 Axis3::Z => layer_indexes,
                 _ => fixed_indexes,
@@ -92,15 +91,16 @@ impl PivotTable {
         omit_empty: bool,
     ) -> AxisEnumeration {
         let axis = &self.axes[enum_axis];
+        let extent = self.axis_extent(enum_axis);
         let indexes = if axis.dimensions.is_empty() {
             vec![0]
-        } else if axis.extent == 0 {
+        } else if extent == 0 {
             vec![]
         } else {
             let mut enumeration =
-                Vec::with_capacity(axis.extent.checked_mul(axis.dimensions.len()).unwrap());
+                Vec::with_capacity(extent.checked_mul(axis.dimensions.len()).unwrap());
             if omit_empty {
-                for axis_indexes in axis.iter() {
+                for axis_indexes in self.axis_values(enum_axis) {
                     if !self.is_row_empty(layer_indexes, &axis_indexes, enum_axis) {
                         enumeration.extend_from_slice(&axis_indexes);
                     }
@@ -108,7 +108,7 @@ impl PivotTable {
             }
 
             if enumeration.is_empty() {
-                for axis_indexes in axis.iter() {
+                for axis_indexes in self.axis_values(enum_axis) {
                     enumeration.extend_from_slice(&axis_indexes);
                 }
             }
@@ -174,10 +174,9 @@ impl PivotTable {
             self.as_value_options(),
         );
         compose_headings(
+            self,
             &mut body,
-            &self.axes[Axis3::X],
             Axis2::X,
-            &self.axes[Axis3::Y],
             &column_enumeration,
             RowColBorder::ColHorz,
             RowColBorder::ColVert,
@@ -186,10 +185,9 @@ impl PivotTable {
             Area::ColumnLabels,
         );
         compose_headings(
+            self,
             &mut body,
-            &self.axes[Axis3::Y],
             Axis2::Y,
-            &self.axes[Axis3::X],
             &row_enumeration,
             RowColBorder::RowVert,
             RowColBorder::RowHorz,
@@ -307,11 +305,12 @@ impl PivotTable {
         }
     }
 
-    fn nonempty_layer_dimensions(&self) -> impl Iterator<Item = &Arc<Dimension>> {
+    fn nonempty_layer_dimensions(&self) -> impl Iterator<Item = &Dimension> {
         self.axes[Axis3::Z]
             .dimensions
             .iter()
             .rev()
+            .map(|index| &self.dimensions[*index])
             .filter(|d| !d.data_leaves.is_empty())
     }
 
@@ -365,10 +364,9 @@ fn find_category<'a>(
 /// instead uses 'h', which is set to H for column headings and V for row
 /// headings.
 fn compose_headings(
+    pt: &PivotTable,
     table: &mut Table,
-    h_axis: &Axis,
     h: Axis2,
-    v_axis: &Axis,
     column_enumeration: &AxisEnumeration,
     col_horz: RowColBorder,
     col_vert: RowColBorder,
@@ -377,6 +375,8 @@ fn compose_headings(
     area: Area,
 ) {
     let v = !h;
+    let h_axis = &pt.axes[h.into()];
+    let v_axis = &pt.axes[v.into()];
     let v_size = h_axis.label_depth;
     let h_ofs = v_axis.label_depth;
     let n_columns = column_enumeration.len();
@@ -443,9 +443,8 @@ fn compose_headings(
     vrules[0] = true;
     vrules[n_columns] = true;
 
-    for (dim_index, d) in h_axis
-        .dimensions
-        .iter()
+    for (dim_index, d) in pt
+        .axis_dimensions(h.into())
         .enumerate()
         .rev()
         .filter(|(_, d)| !d.hide_all_labels)