progress! and clippy!
author Ben Pfaff <blp@cs.stanford.edu>
Tue, 30 Jan 2024 05:17:49 +0000 (21:17 -0800)
committer Ben Pfaff <blp@cs.stanford.edu>
Tue, 30 Jan 2024 05:17:49 +0000 (21:17 -0800)
rust/build.rs
rust/src/cooked.rs
rust/src/dictionary.rs
rust/src/main.rs
rust/src/raw.rs

index baaacc5d4f9edcee5a46c27448a88b1298cffbae..f8cb9efa13a7439764c0773ff4a77caa6c30dbfa 100644 (file: rust/build.rs)
@@ -83,7 +83,7 @@ fn process_converter<'a>(
     for (source, number) in cps {
         codepages
             .entry(number)
-            .or_insert_with(BTreeMap::new)
+            .or_default()
             .insert(source, all.clone());
     }
 }
@@ -127,9 +127,9 @@ lazy_static! {
             for name in value2.iter().map(|name| name.to_ascii_lowercase()) {
                 names
                     .entry(name)
-                    .or_insert_with(BTreeMap::new)
+                    .or_default()
                     .entry(source)
-                    .or_insert_with(Vec::new)
+                    .or_default()
                     .push(cpnumber);
             }
         }
index 894204b6b386f44b280f8070480a83ad76c147f3..1cb878f120d6768cff919e86c60c368a9c8347a5 100644 (file: rust/src/cooked.rs)
@@ -283,31 +283,6 @@ impl<'a> Headers<'a> {
     }
 }
 
-pub fn encoding_from_headers(
-    headers: &Vec<raw::Record>,
-    warn: &impl Fn(Error),
-) -> Result<&'static Encoding, Error> {
-    let mut encoding_record = None;
-    let mut integer_info_record = None;
-    for record in headers {
-        match record {
-            raw::Record::Encoding(record) => encoding_record = Some(record),
-            raw::Record::IntegerInfo(record) => integer_info_record = Some(record),
-            _ => (),
-        }
-    }
-    let encoding = encoding_record.map(|record| record.0.as_str());
-    let character_code = integer_info_record.map(|record| record.character_code);
-    match get_encoding(encoding, character_code) {
-        Ok(encoding) => Ok(encoding),
-        Err(err @ EncodingError::Ebcdic) => Err(Error::EncodingError(err)),
-        Err(err) => {
-            warn(Error::EncodingError(err));
-            // Warn that we're using the default encoding.
-            Ok(default_encoding())
-        }
-    }
-}
 
 pub fn decode(
     headers: Vec<raw::Record>,
index ace97b03f7db5690d2a0ff565d4b0f2ced0ee242..59e1e3a853dcc59191d75d61d650dce27b37be5f 100644 (file: rust/src/dictionary.rs)
@@ -1,7 +1,8 @@
 use std::{
+    cmp::Ordering,
     collections::{HashMap, HashSet},
     fmt::Debug,
-    ops::{Bound, RangeBounds}, cmp::Ordering,
+    ops::{Bound, RangeBounds},
 };
 
 use encoding_rs::Encoding;
@@ -12,7 +13,7 @@ use ordered_float::OrderedFloat;
 use crate::{
     format::Spec,
     identifier::{ByIdentifier, HasIdentifier, Identifier},
-    raw::{Alignment, CategoryLabels, Measure, MissingValues, VarType, self, RawStr, Decoder},
+    raw::{self, Alignment, CategoryLabels, Decoder, Measure, MissingValues, RawStr, VarType},
 };
 
 pub type DictIndex = usize;
@@ -119,6 +120,8 @@ pub struct Dictionary {
     pub encoding: &'static Encoding,
 }
 
+pub struct DuplicateVariableName;
+
 impl Dictionary {
     pub fn new(encoding: &'static Encoding) -> Self {
         Self {
@@ -137,11 +140,11 @@ impl Dictionary {
         }
     }
 
-    pub fn add_var(&mut self, variable: Variable) -> Result<(), ()> {
+    pub fn add_var(&mut self, variable: Variable) -> Result<(), DuplicateVariableName> {
         if self.variables.insert(ByIdentifier::new(variable)) {
             Ok(())
         } else {
-            Err(())
+            Err(DuplicateVariableName)
         }
     }
 
@@ -149,6 +152,7 @@ impl Dictionary {
         if from_index != to_index {
             self.variables.move_index(from_index, to_index);
             self.update_dict_indexes(&|index| {
+                #[allow(clippy::collapsible_else_if)]
                 if index == from_index {
                     Some(to_index)
                 } else if from_index < to_index {
@@ -223,8 +227,8 @@ impl Dictionary {
         F: Fn(DictIndex) -> Option<DictIndex>,
     {
         update_dict_index_vec(&mut self.split_file, f);
-        self.weight = self.weight.map(|index| f(index)).flatten();
-        self.filter = self.filter.map(|index| f(index)).flatten();
+        self.weight = self.weight.and_then(f);
+        self.filter = self.filter.and_then(f);
         self.vectors = self
             .vectors
             .drain()
@@ -232,7 +236,7 @@ impl Dictionary {
                 vector_by_id
                     .0
                     .with_updated_dict_indexes(f)
-                    .map(|vector| ByIdentifier::new(vector))
+                    .map(ByIdentifier::new)
             })
             .collect();
         self.mrsets = self
@@ -242,7 +246,7 @@ impl Dictionary {
                 mrset_by_id
                     .0
                     .with_updated_dict_indexes(f)
-                    .map(|mrset| ByIdentifier::new(mrset))
+                    .map(ByIdentifier::new)
             })
             .collect();
         self.variable_sets = self
@@ -252,7 +256,7 @@ impl Dictionary {
                 var_set_by_id
                     .0
                     .with_updated_dict_indexes(f)
-                    .map(|var_set| ByIdentifier::new(var_set))
+                    .map(ByIdentifier::new)
             })
             .collect();
     }
@@ -350,7 +354,7 @@ impl Variable {
             alignment: Alignment::default_for_type(var_type),
             leave,
             short_names: Vec::new(),
-            attributes: HashSet::new()
+            attributes: HashSet::new(),
         }
     }
 }
index dece725c65073991a9c08e3570f54f750d1666e7..a3b3145bedff7c4edbf34bb2fc3440da5c787f43 100644 (file: rust/src/main.rs)
@@ -17,7 +17,7 @@
 use anyhow::Result;
 use clap::{Parser, ValueEnum};
 use encoding_rs::Encoding;
-use pspp::raw::{Magic, Reader, Record};
+use pspp::raw::{encoding_from_headers, Decoder, Magic, Reader, Record};
 use std::fs::File;
 use std::io::BufReader;
 use std::path::{Path, PathBuf};
@@ -60,6 +60,7 @@ fn parse_encoding(arg: &str) -> Result<&'static Encoding, UnknownEncodingError>
 enum Mode {
     Identify,
     Raw,
+    Decoded,
     #[default]
     Cooked,
 }
@@ -115,11 +116,29 @@ fn dissect(
                 }
             }
         }
-/*
         Mode::Decoded => {
             let headers: Vec<Record> = reader.collect::<Result<Vec<_>, _>>()?;
+            let encoding = match encoding {
+                Some(encoding) => encoding,
+                None => encoding_from_headers(&headers, &|e| eprintln!("{e}"))?,
+            };
+            let decoder = Decoder::new(encoding, |e| eprintln!("{e}"));
+            for header in headers {
+                let header = header.decode(&decoder);
+                println!("{:?}", header);
+                /*
+                                if let Record::Cases(cases) = header {
+                                    let mut cases = cases.borrow_mut();
+                                    for _ in 0..max_cases {
+                                        let Some(Ok(record)) = cases.next() else {
+                                            break;
+                                        };
+                                        println!("{:?}", record);
+                                    }
+                                }
+                */
+            }
         }
-*/
         Mode::Cooked => {
             /*
                 let headers: Vec<Record> = reader.collect::<Result<Vec<_>, _>>()?;
index a38422950ad9fc2d1c8d971756a9c33829b89cb1..1a760a0517dc26168b6084be90266fb7227f8007 100644 (file: rust/src/raw.rs)
@@ -1,5 +1,6 @@
 use crate::{
     dictionary::VarWidth,
+    encoding::{default_encoding, get_encoding, Error as EncodingError},
     endian::{Endian, Parse, ToBytes},
     identifier::{Error as IdError, Identifier},
 };
@@ -105,6 +106,9 @@ pub enum Error {
         expected_n_blocks: u64,
         ztrailer_len: u64,
     },
+
+    #[error("{0}")]
+    EncodingError(EncodingError),
 }
 
 #[derive(ThisError, Debug)]
@@ -195,6 +199,9 @@ pub enum Warning {
     #[error("Invalid variable name in long string value label record.  {0}")]
     InvalidLongStringValueLabelName(IdError),
 
+    #[error("{0}")]
+    EncodingError(EncodingError),
+
     #[error("Details TBD")]
     TBD,
 }
@@ -227,6 +234,7 @@ pub enum Record {
     Cases(Rc<RefCell<Cases>>),
 }
 
+#[derive(Clone, Debug)]
 pub enum DecodedRecord<'a> {
     Header(HeaderRecord<Cow<'a, str>>),
     Variable(VariableRecord<Cow<'a, str>, String>),
@@ -257,7 +265,7 @@ impl Record {
         reader: &mut R,
         endian: Endian,
         var_types: &[VarType],
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error>
     where
         R: Read + Seek,
@@ -298,7 +306,7 @@ impl Record {
             Record::NumberOfCases(record) => DecodedRecord::NumberOfCases(record.clone()),
             Record::Text(record) => record.decode(decoder),
             Record::OtherExtension(record) => DecodedRecord::OtherExtension(record.clone()),
-            Record::EndOfHeaders(record) => DecodedRecord::EndOfHeaders(record.clone()),
+            Record::EndOfHeaders(record) => DecodedRecord::EndOfHeaders(*record),
             Record::ZHeader(record) => DecodedRecord::ZHeader(record.clone()),
             Record::ZTrailer(record) => DecodedRecord::ZTrailer(record.clone()),
             Record::Cases(_) => todo!(),
@@ -306,6 +314,32 @@ impl Record {
     }
 }
 
+pub fn encoding_from_headers(
+    headers: &Vec<Record>,
+    warn: &impl Fn(Warning),
+) -> Result<&'static Encoding, Error> {
+    let mut encoding_record = None;
+    let mut integer_info_record = None;
+    for record in headers {
+        match record {
+            Record::Encoding(record) => encoding_record = Some(record),
+            Record::IntegerInfo(record) => integer_info_record = Some(record),
+            _ => (),
+        }
+    }
+    let encoding = encoding_record.map(|record| record.0.as_str());
+    let character_code = integer_info_record.map(|record| record.character_code);
+    match get_encoding(encoding, character_code) {
+        Ok(encoding) => Ok(encoding),
+        Err(err @ EncodingError::Ebcdic) => Err(Error::EncodingError(err)),
+        Err(err) => {
+            warn(Warning::EncodingError(err));
+            // Warn that we're using the default encoding.
+            Ok(default_encoding())
+        }
+    }
+}
+
 // If `s` is valid UTF-8, returns it decoded as UTF-8, otherwise returns it
 // decoded as Latin-1 (actually bytes interpreted as Unicode code points).
 fn default_decode(s: &[u8]) -> Cow<str> {
@@ -489,6 +523,15 @@ pub struct Decoder {
 }
 
 impl Decoder {
+    pub fn new<F>(encoding: &'static Encoding, warn: F) -> Self
+    where
+        F: Fn(Warning) + 'static,
+    {
+        Self {
+            encoding,
+            warn: Box::new(warn),
+        }
+    }
     fn warn(&self, warning: Warning) {
         (self.warn)(warning)
     }
@@ -881,15 +924,7 @@ where
             &self.header,
         )
     }
-}
-
-impl<R> Iterator for Reader<R>
-where
-    R: Read + Seek + 'static,
-{
-    type Item = Result<Record, Error>;
-
-    fn next(&mut self) -> Option<Self::Item> {
+    fn _next(&mut self) -> Option<<Self as Iterator>::Item> {
         match self.state {
             ReaderState::Start => {
                 self.state = ReaderState::Headers;
@@ -960,6 +995,21 @@ where
     }
 }
 
+impl<R> Iterator for Reader<R>
+where
+    R: Read + Seek + 'static,
+{
+    type Item = Result<Record, Error>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let retval = self._next();
+        if matches!(retval, Some(Err(_))) {
+            self.state = ReaderState::End;
+        }
+        retval
+    }
+}
+
 trait ReadSeek: Read + Seek {}
 impl<T> ReadSeek for T where T: Read + Seek {}
 
@@ -1176,7 +1226,7 @@ impl MissingValues<RawStr<8>> {
         };
         Ok(Self { values, range })
     }
-    fn decode<'a>(&'a self, decoder: &Decoder) -> MissingValues<String> {
+    fn decode(&self, decoder: &Decoder) -> MissingValues<String> {
         MissingValues {
             values: self
                 .values
@@ -1291,7 +1341,7 @@ impl VariableRecord<RawString, RawStr<8>> {
         }))
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+    pub fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         DecodedRecord::Variable(VariableRecord {
             offsets: self.offsets.clone(),
             width: self.width,
@@ -1440,7 +1490,7 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
         r: &mut R,
         endian: Endian,
         var_types: &[VarType],
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error> {
         let label_offset = r.stream_position()?;
         let n: u32 = endian.parse(read_bytes(r)?);
@@ -1547,7 +1597,7 @@ impl ValueLabelRecord<RawStr<8>, RawString> {
             .labels
             .iter()
             .map(|ValueLabel { value, label }| ValueLabel {
-                value: value.clone(),
+                value: *value,
                 label: decoder.decode(label),
             })
             .collect();
@@ -1605,7 +1655,7 @@ impl DocumentRecord<RawDocumentLine> {
         }
     }
 
-    pub fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+    pub fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         DecodedRecord::Document(DocumentRecord {
             offsets: self.offsets.clone(),
             lines: self
@@ -1808,7 +1858,7 @@ impl MultipleResponseSet<RawString, RawString> {
         for short_name in self.short_names.iter() {
             if let Some(short_name) = decoder
                 .decode_identifier(short_name)
-                .map_err(|err| Warning::InvalidMrSetName(err))
+                .map_err(Warning::InvalidMrSetName)
                 .issue_warning(&decoder.warn)
             {
                 short_names.push(short_name);
@@ -1817,10 +1867,10 @@ impl MultipleResponseSet<RawString, RawString> {
         Ok(MultipleResponseSet {
             name: decoder
                 .decode_identifier(&self.name)
-                .map_err(|err| Warning::InvalidMrSetVariableName(err))?,
+                .map_err(Warning::InvalidMrSetVariableName)?,
             label: decoder.decode(&self.label),
             mr_type: self.mr_type.clone(),
-            short_names: short_names,
+            short_names,
         })
     }
 }
@@ -1852,7 +1902,7 @@ impl ExtensionRecord for MultipleResponseRecord<RawString, RawString> {
 }
 
 impl MultipleResponseRecord<RawString, RawString> {
-    fn decode<'a>(&'a self, decoder: &Decoder) -> DecodedRecord {
+    fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         let mut sets = Vec::new();
         for set in self.0.iter() {
             if let Some(set) = set.decode(decoder).issue_warning(&decoder.warn) {
@@ -1952,7 +2002,7 @@ impl VarDisplayRecord {
         ext: &Extension,
         n_vars: usize,
         endian: Endian,
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Record, Warning> {
         if ext.size != 4 {
             return Err(Warning::BadRecordSize {
@@ -2005,7 +2055,7 @@ where
 }
 
 impl LongStringMissingValues<RawString, RawStr<8>> {
-    fn decode<'a>(
+    fn decode(
         &self,
         decoder: &Decoder,
     ) -> Result<LongStringMissingValues<Identifier, String>, IdError> {
@@ -2075,15 +2125,12 @@ impl ExtensionRecord for LongStringMissingValueRecord<RawString, RawStr<8>> {
 }
 
 impl LongStringMissingValueRecord<RawString, RawStr<8>> {
-    pub fn decode<'a>(
-        &self,
-        decoder: &Decoder,
-    ) -> LongStringMissingValueRecord<Identifier, String> {
+    pub fn decode(&self, decoder: &Decoder) -> LongStringMissingValueRecord<Identifier, String> {
         let mut mvs = Vec::with_capacity(self.0.len());
         for mv in self.0.iter() {
             if let Some(mv) = mv
                 .decode(decoder)
-                .map_err(|err| Warning::InvalidLongStringMissingValueVariableName(err))
+                .map_err(Warning::InvalidLongStringMissingValueVariableName)
                 .issue_warning(&decoder.warn)
             {
                 mvs.push(mv);
@@ -2113,7 +2160,7 @@ impl ExtensionRecord for EncodingRecord {
     }
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Clone, Debug)]
 pub struct NumberOfCasesRecord {
     /// Always observed as 1.
     pub one: u64,
@@ -2168,7 +2215,7 @@ impl TextRecord {
             text: extension.data.into(),
         }
     }
-    pub fn decode<'a>(&self, decoder: &Decoder) -> DecodedRecord {
+    pub fn decode(&self, decoder: &Decoder) -> DecodedRecord {
         match self.rec_type {
             TextRecordType::VariableSets => {
                 DecodedRecord::VariableSets(VariableSetRecord::decode(self, decoder))
@@ -2268,7 +2315,7 @@ impl Attribute {
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct AttributeSet(pub HashMap<Identifier, Vec<String>>);
 
 impl AttributeSet {
@@ -2294,13 +2341,7 @@ impl AttributeSet {
     }
 }
 
-impl Default for AttributeSet {
-    fn default() -> Self {
-        Self(HashMap::default())
-    }
-}
-
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct FileAttributeRecord(AttributeSet);
 
 impl FileAttributeRecord {
@@ -2318,12 +2359,6 @@ impl FileAttributeRecord {
     }
 }
 
-impl Default for FileAttributeRecord {
-    fn default() -> Self {
-        Self(AttributeSet::default())
-    }
-}
-
 #[derive(Clone, Debug)]
 pub struct VarAttributeSet {
     pub long_var_name: Identifier,
@@ -2357,12 +2392,12 @@ impl VariableAttributeRecord {
         let mut var_attribute_sets = Vec::new();
         while !input.is_empty() {
             let Some((var_attribute, rest)) =
-                VarAttributeSet::parse(decoder, &input).issue_warning(&decoder.warn)
+                VarAttributeSet::parse(decoder, input).issue_warning(&decoder.warn)
             else {
                 break;
             };
             var_attribute_sets.push(var_attribute);
-            input = rest.into();
+            input = rest;
         }
         VariableAttributeRecord(var_attribute_sets)
     }
@@ -2530,7 +2565,7 @@ impl Extension {
         r: &mut R,
         endian: Endian,
         n_vars: usize,
-        warn: &Box<dyn Fn(Warning)>,
+        warn: &dyn Fn(Warning),
     ) -> Result<Option<Record>, Error> {
         let subtype = endian.parse(read_bytes(r)?);
         let header_offset = r.stream_position()?;
@@ -2772,7 +2807,7 @@ impl LongStringValueLabels<RawString, RawString> {
         let mut labels = Vec::with_capacity(self.labels.len());
         for (value, label) in self.labels.iter() {
             let value = decoder.decode_exact_length(&value.0);
-            let label = decoder.decode(&label);
+            let label = decoder.decode(label);
             labels.push((value, label));
         }