e9eca118abe71218f8122e70629eccfcf820500f
[pspp] / rust / src / dictionary.rs
1 use std::{
2     collections::{HashMap, HashSet},
3     fmt::Debug,
4     ops::{Bound, RangeBounds},
5 };
6
7 use encoding_rs::Encoding;
8 use indexmap::IndexSet;
9
10 use crate::{
11     cooked::{Value, VarWidth},
12     format::Spec,
13     identifier::{ByIdentifier, HasIdentifier, Identifier},
14     raw::{Alignment, CategoryLabels, Measure, MissingValues, VarType},
15 };
16
/// Position of a variable within [`Dictionary::variables`].
///
/// Other parts of a [`Dictionary`] (split file, weight, filter, vectors,
/// multiple response sets, variable sets) refer to variables by this index
/// rather than by name.
pub type DictIndex = usize;
18
/// An ordered set of variables plus file-level metadata that refers to them.
///
/// Every field that holds a [`DictIndex`] indexes into `variables`; the
/// methods on `Dictionary` that move or remove variables rewrite those
/// indexes so that they keep pointing at the same variables.
#[derive(Clone, Debug)]
pub struct Dictionary {
    /// The variables, in dictionary order, unique by identifier.
    pub variables: IndexSet<ByIdentifier<Variable>>,
    /// Indexes of the split file variables.
    pub split_file: Vec<DictIndex>,
    /// Index of the weight variable, if any.
    pub weight: Option<DictIndex>,
    /// Index of the filter variable, if any.
    pub filter: Option<DictIndex>,
    /// Case limit, if any.
    pub case_limit: Option<u64>,
    /// File label, if any.
    pub file_label: Option<String>,
    /// Document lines.
    pub documents: Vec<String>,
    /// Named vectors of variables, unique by identifier.
    pub vectors: HashSet<ByIdentifier<Vector>>,
    /// Dictionary-level attributes, unique by identifier.
    pub attributes: HashSet<ByIdentifier<Attribute>>,
    /// Multiple response sets, unique by identifier.
    pub mrsets: HashSet<ByIdentifier<MultipleResponseSet>>,
    /// Variable sets, unique by identifier.
    pub variable_sets: HashSet<ByIdentifier<VariableSet>>,
    /// Character encoding associated with the dictionary.
    pub encoding: &'static Encoding,
}
34
35 impl Dictionary {
36     pub fn new(encoding: &'static Encoding) -> Self {
37         Self {
38             variables: IndexSet::new(),
39             split_file: Vec::new(),
40             weight: None,
41             filter: None,
42             case_limit: None,
43             file_label: None,
44             documents: Vec::new(),
45             vectors: HashSet::new(),
46             attributes: HashSet::new(),
47             mrsets: HashSet::new(),
48             variable_sets: HashSet::new(),
49             encoding,
50         }
51     }
52
53     pub fn add_var(&mut self, variable: Variable) -> Result<(), ()> {
54         if self.variables.insert(ByIdentifier::new(variable)) {
55             Ok(())
56         } else {
57             Err(())
58         }
59     }
60
61     pub fn reorder_var(&mut self, from_index: DictIndex, to_index: DictIndex) {
62         if from_index != to_index {
63             self.variables.move_index(from_index, to_index);
64             self.update_dict_indexes(&|index| {
65                 if index == from_index {
66                     Some(to_index)
67                 } else if from_index < to_index {
68                     if index > from_index && index <= to_index {
69                         Some(index - 1)
70                     } else {
71                         Some(index)
72                     }
73                 } else {
74                     if index >= to_index && index < from_index {
75                         Some(index + 1)
76                     } else {
77                         Some(index)
78                     }
79                 }
80             })
81         }
82     }
83
84     pub fn retain_vars<F>(&mut self, keep: F)
85     where
86         F: Fn(&Variable) -> bool,
87     {
88         let mut deleted = Vec::new();
89         let mut index = 0;
90         self.variables.retain(|var_by_id| {
91             let keep = keep(&var_by_id.0);
92             if !keep {
93                 deleted.push(index);
94             }
95             index += 1;
96             keep
97         });
98         if !deleted.is_empty() {
99             self.update_dict_indexes(&|index| match deleted.binary_search(&index) {
100                 Ok(_) => None,
101                 Err(position) => Some(position),
102             })
103         }
104     }
105
106     pub fn delete_vars<R>(&mut self, range: R)
107     where
108         R: RangeBounds<DictIndex>,
109     {
110         let start = match range.start_bound() {
111             Bound::Included(&start) => start,
112             Bound::Excluded(&start) => start + 1,
113             Bound::Unbounded => 0,
114         };
115         let end = match range.end_bound() {
116             Bound::Included(&end) => end + 1,
117             Bound::Excluded(&end) => end,
118             Bound::Unbounded => self.variables.len(),
119         };
120         if end > start {
121             self.variables.drain(start..end);
122             self.update_dict_indexes(&|index| {
123                 if index < start {
124                     Some(index)
125                 } else if index < end {
126                     None
127                 } else {
128                     Some(index - end - start)
129                 }
130             })
131         }
132     }
133
134     fn update_dict_indexes<F>(&mut self, f: &F)
135     where
136         F: Fn(DictIndex) -> Option<DictIndex>,
137     {
138         update_dict_index_vec(&mut self.split_file, f);
139         self.weight = self.weight.map(|index| f(index)).flatten();
140         self.filter = self.filter.map(|index| f(index)).flatten();
141         self.vectors = self
142             .vectors
143             .drain()
144             .filter_map(|vector_by_id| {
145                 vector_by_id
146                     .0
147                     .with_updated_dict_indexes(f)
148                     .map(|vector| ByIdentifier::new(vector))
149             })
150             .collect();
151         self.mrsets = self
152             .mrsets
153             .drain()
154             .filter_map(|mrset_by_id| {
155                 mrset_by_id
156                     .0
157                     .with_updated_dict_indexes(f)
158                     .map(|mrset| ByIdentifier::new(mrset))
159             })
160             .collect();
161         self.variable_sets = self
162             .variable_sets
163             .drain()
164             .filter_map(|var_set_by_id| {
165                 var_set_by_id
166                     .0
167                     .with_updated_dict_indexes(f)
168                     .map(|var_set| ByIdentifier::new(var_set))
169             })
170             .collect();
171     }
172 }
173
174 fn update_dict_index_vec<F>(dict_indexes: &mut Vec<DictIndex>, f: F)
175 where
176     F: Fn(DictIndex) -> Option<DictIndex>,
177 {
178     dict_indexes.retain_mut(|index| {
179         if let Some(new) = f(*index) {
180             *index = new;
181             true
182         } else {
183             false
184         }
185     });
186 }
187
/// The role assigned to a variable.
///
/// The default role is [`Role::Input`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Hash)]
pub enum Role {
    #[default]
    Input,
    Target,
    Both,
    None,
    Partition,
    Split,
}
203
/// The class of a variable, as determined by the first character of its name
/// (see `DictClass::from_identifier`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum DictClass {
    /// A name that starts with neither `$` nor `#`.
    Ordinary,
    /// A name that starts with `$`.
    System,
    /// A name that starts with `#`.
    Scratch,
}
209
210 impl DictClass {
211     pub fn from_identifier(id: &Identifier) -> Self {
212         if id.0.starts_with('$') {
213             Self::System
214         } else if id.0.starts_with('#') {
215             Self::Scratch
216         } else {
217             Self::Ordinary
218         }
219     }
220
221     pub fn must_leave(self) -> bool {
222         match self {
223             DictClass::Ordinary => false,
224             DictClass::System => false,
225             DictClass::Scratch => true,
226         }
227     }
228 }
229
/// A variable in a [`Dictionary`].
#[derive(Clone, Debug)]
pub struct Variable {
    /// The variable's name, which is also its identifier within the
    /// dictionary.
    pub name: Identifier,
    /// The variable's width.
    pub width: VarWidth,
    /// User-missing values.
    pub missing_values: MissingValues,
    /// Print format.
    pub print_format: Spec,
    /// Write format.
    pub write_format: Spec,
    /// Labels for individual values.
    pub value_labels: HashMap<Value, String>,
    /// Variable label, if any.
    pub label: Option<String>,
    /// Measurement level, if any.
    pub measure: Option<Measure>,
    /// The variable's role.
    pub role: Role,
    /// Display width.
    // NOTE(review): units are not visible in this file -- confirm.
    pub display_width: u32,
    /// Display alignment.
    pub alignment: Alignment,
    /// Leave flag.  `Variable::new` sets this for variables whose names start
    /// with `#`.
    pub leave: bool,
    /// Short names associated with the variable.
    // NOTE(review): presumably one per segment of a very long string
    // variable -- confirm against the readers and writers.
    pub short_names: Vec<Identifier>,
    /// Variable-level attributes, unique by identifier.
    pub attributes: HashSet<ByIdentifier<Attribute>>,
}
247
248 impl Variable {
249     pub fn new(name: Identifier, width: VarWidth) -> Self {
250         let var_type = VarType::from_width(width);
251         let leave = DictClass::from_identifier(&name).must_leave();
252         Self {
253             name,
254             width,
255             missing_values: MissingValues::default(),
256             print_format: Spec::default_for_width(width),
257             write_format: Spec::default_for_width(width),
258             value_labels: HashMap::new(),
259             label: None,
260             measure: Measure::default_for_type(var_type),
261             role: Role::default(),
262             display_width: width.default_display_width(),
263             alignment: Alignment::default_for_type(var_type),
264             leave,
265             short_names: Vec::new(),
266             attributes: HashSet::new()
267         }
268     }
269 }
270
271 impl HasIdentifier for Variable {
272     fn identifier(&self) -> &Identifier {
273         &self.name
274     }
275 }
276
/// A named, ordered list of variables in a [`Dictionary`].
#[derive(Clone, Debug)]
pub struct Vector {
    /// The vector's name, which is also its identifier.
    pub name: Identifier,
    /// Indexes of the vector's variables.
    pub variables: Vec<DictIndex>,
}
282
283 impl Vector {
284     fn with_updated_dict_indexes(
285         mut self,
286         f: impl Fn(DictIndex) -> Option<DictIndex>,
287     ) -> Option<Self> {
288         update_dict_index_vec(&mut self.variables, f);
289         (!self.variables.is_empty()).then_some(self)
290     }
291 }
292
293 impl HasIdentifier for Vector {
294     fn identifier(&self) -> &Identifier {
295         &self.name
296     }
297 }
298
/// A named attribute with zero or more string values, attached either to a
/// [`Dictionary`] or to a [`Variable`].
#[derive(Clone, Debug)]
pub struct Attribute {
    /// The attribute's name, which is also its identifier.
    pub name: Identifier,
    /// The attribute's values.
    pub values: Vec<String>,
}
304
305 impl HasIdentifier for Attribute {
306     fn identifier(&self) -> &Identifier {
307         &self.name
308     }
309 }
310
/// A multiple response set: a named group of variables in a [`Dictionary`].
#[derive(Clone, Debug)]
pub struct MultipleResponseSet {
    /// The set's name, which is also its identifier.
    pub name: Identifier,
    /// The set's label.
    pub label: String,
    /// The kind of multiple response set.
    pub mr_type: MultipleResponseType,
    /// Indexes of the set's variables.  When variables are removed from the
    /// dictionary, a set left with fewer than two variables is dropped.
    pub variables: Vec<DictIndex>,
}
318
319 impl MultipleResponseSet {
320     fn with_updated_dict_indexes(
321         mut self,
322         f: impl Fn(DictIndex) -> Option<DictIndex>,
323     ) -> Option<Self> {
324         update_dict_index_vec(&mut self.variables, f);
325         (self.variables.len() > 1).then_some(self)
326     }
327 }
328
329 impl HasIdentifier for MultipleResponseSet {
330     fn identifier(&self) -> &Identifier {
331         &self.name
332     }
333 }
334
/// The kind of a [`MultipleResponseSet`].
#[derive(Clone, Debug)]
pub enum MultipleResponseType {
    /// A multiple dichotomy set.
    // NOTE(review): `value` is presumably the counted value and `labels` the
    // source of category labels -- confirm against the system file reader.
    MultipleDichotomy {
        value: Value,
        labels: CategoryLabels,
    },
    /// A multiple category set.
    MultipleCategory,
}
343
/// A named set of variables in a [`Dictionary`].
#[derive(Clone, Debug)]
pub struct VariableSet {
    /// The set's name, which is also its identifier.
    pub name: Identifier,
    /// Indexes of the set's variables.
    pub variables: Vec<DictIndex>,
}
349
350 impl VariableSet {
351     fn with_updated_dict_indexes(
352         mut self,
353         f: impl Fn(DictIndex) -> Option<DictIndex>,
354     ) -> Option<Self> {
355         update_dict_index_vec(&mut self.variables, f);
356         (!self.variables.is_empty()).then_some(self)
357     }
358 }
359
360 impl HasIdentifier for VariableSet {
361     fn identifier(&self) -> &Identifier {
362         &self.name
363     }
364 }
365
/// Checks that `ByIdentifier` compares and hashes its contents by identifier
/// only, using a small stand-in `Variable` type.
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::identifier::Identifier;

    use super::{ByIdentifier, HasIdentifier};

    /// Minimal stand-in for a dictionary variable: a name plus a payload that
    /// takes part in `PartialEq` but not in the identifier.
    #[derive(PartialEq, Eq, Debug, Clone)]
    struct Variable {
        name: Identifier,
        value: i32,
    }

    impl HasIdentifier for Variable {
        fn identifier(&self) -> &Identifier {
            &self.name
        }
    }

    #[test]
    fn test() {
        // Variables should not be the same if their values differ.
        let abcd = Identifier::new_utf8("abcd").unwrap();
        let abcd1 = Variable {
            name: abcd.clone(),
            value: 1,
        };
        let abcd2 = Variable {
            name: abcd,
            value: 2,
        };
        assert_ne!(abcd1, abcd2);

        // But `ByIdentifier` should treat them the same.
        let abcd1_by_name = ByIdentifier::new(abcd1);
        let abcd2_by_name = ByIdentifier::new(abcd2);
        assert_eq!(abcd1_by_name, abcd2_by_name);

        // And a `HashSet` of `ByIdentifier` should also treat them the same:
        // the second insertion is rejected and the first value is kept.
        let mut vars: HashSet<ByIdentifier<Variable>> = HashSet::new();
        assert!(vars.insert(ByIdentifier::new(abcd1_by_name.0.clone())));
        assert!(!vars.insert(ByIdentifier::new(abcd2_by_name.0.clone())));
        assert_eq!(
            vars.get(&Identifier::new_utf8("abcd").unwrap())
                .unwrap()
                .0
                .value,
            1
        );
    }
}