1e3b4fff1189e5bd96526c9442c21ddf8571f9ea
[pspp] / rust / src / sack.rs
1 use float_next_after::NextAfter;
2 use num::{Bounded, Zero};
3 use ordered_float::OrderedFloat;
4 use std::{
5     collections::{hash_map::Entry, HashMap},
6     error::Error as StdError,
7     fmt::{Display, Formatter, Result as FmtResult},
8     iter::{repeat, Peekable},
9     str::Chars,
10 };
11
12 use crate::endian::{Endian, ToBytes};
13
14 pub type Result<T, F = Error> = std::result::Result<T, F>;
15
16 #[derive(Debug)]
17 pub struct Error {
18     pub file_name: Option<String>,
19     pub line_number: Option<usize>,
20     pub message: String,
21 }
22
23 impl Error {
24     fn new(file_name: Option<&str>, line_number: Option<usize>, message: String) -> Error {
25         Error {
26             file_name: file_name.map(String::from),
27             line_number,
28             message,
29         }
30     }
31 }
32
33 impl StdError for Error {}
34
35 impl Display for Error {
36     fn fmt(&self, f: &mut Formatter) -> FmtResult {
37         if let Some(ref file_name) = self.file_name {
38             write!(f, "{file_name}:")?;
39             if self.line_number.is_none() {
40                 write!(f, " ")?;
41             }
42         }
43         if let Some(line_number) = self.line_number {
44             write!(f, "{line_number}: ")?;
45         }
46         write!(f, "{}", self.message)
47     }
48 }
49
50 pub fn sack(input: &str, input_file_name: Option<&str>, endian: Endian) -> Result<Vec<u8>> {
51     let mut symbol_table = HashMap::new();
52     let output = _sack(input, input_file_name, endian, &mut symbol_table)?;
53     let output = if !symbol_table.is_empty() {
54         for (k, v) in symbol_table.iter() {
55             if v.is_none() {
56                 Err(Error::new(
57                     input_file_name,
58                     None,
59                     format!("label {k} used but never defined"),
60                 ))?
61             }
62         }
63         _sack(input, input_file_name, endian, &mut symbol_table)?
64     } else {
65         output
66     };
67     Ok(output)
68 }
69
70 fn _sack(
71     input: &str,
72     input_file_name: Option<&str>,
73     endian: Endian,
74     symbol_table: &mut HashMap<String, Option<u32>>,
75 ) -> Result<Vec<u8>> {
76     let mut lexer = Lexer::new(input, input_file_name, endian)?;
77     let mut output = Vec::new();
78     while parse_data_item(&mut lexer, &mut output, symbol_table)? {}
79     Ok(output)
80 }
81
82 fn parse_data_item(
83     lexer: &mut Lexer,
84     output: &mut Vec<u8>,
85     symbol_table: &mut HashMap<String, Option<u32>>,
86 ) -> Result<bool> {
87     if lexer.token.is_none() {
88         return Ok(false);
89     };
90
91     let initial_len = output.len();
92     match lexer.take()? {
93         Token::Integer(integer) => output.extend_from_slice(&lexer.endian.to_bytes(integer)),
94         Token::Float(float) => output.extend_from_slice(&lexer.endian.to_bytes(float.0)),
95         Token::PcSysmis => {
96             output.extend_from_slice(&[0xf5, 0x1e, 0x26, 0x02, 0x8a, 0x8c, 0xed, 0xff])
97         }
98         Token::I8 => put_integers::<u8, 1>(lexer, "i8", output)?,
99         Token::I16 => put_integers::<u16, 2>(lexer, "i16", output)?,
100         Token::I64 => put_integers::<i64, 8>(lexer, "i64", output)?,
101         Token::String(string) => output.extend_from_slice(string.as_bytes()),
102         Token::S(size) => {
103             let Some(Token::String(ref string)) = lexer.token else {
104                 Err(lexer.error(format!("string expected after 's{size}'")))?
105             };
106             let len = string.len();
107             if len > size {
108                 Err(lexer.error(format!(
109                     "{len}-byte string is longer than pad length {size}"
110                 )))?
111             }
112             output.extend_from_slice(string.as_bytes());
113             output.extend(repeat(b' ').take(size - len));
114             lexer.get()?;
115         }
116         Token::LParen => {
117             while lexer.token != Some(Token::RParen) {
118                 parse_data_item(lexer, output, symbol_table)?;
119             }
120             lexer.get()?;
121         }
122         Token::Count => put_counted_items::<u32, 4>(lexer, "COUNT", output, symbol_table)?,
123         Token::Count8 => put_counted_items::<u8, 1>(lexer, "COUNT8", output, symbol_table)?,
124         Token::Hex => {
125             let Some(Token::String(ref string)) = lexer.token else {
126                 Err(lexer.error(String::from("string expected after 'hex'")))?
127             };
128             let mut i = string.chars();
129             loop {
130                 let Some(c0) = i.next() else { return Ok(true) };
131                 let Some(c1) = i.next() else {
132                     Err(lexer.error(String::from("hex string has odd number of characters")))?
133                 };
134                 let (Some(digit0), Some(digit1)) = (c0.to_digit(16), c1.to_digit(16)) else {
135                     Err(lexer.error(String::from("invalid digit in hex string")))?
136                 };
137                 let byte = digit0 * 16 + digit1;
138                 output.push(byte as u8);
139             }
140         }
141         Token::Label(name) => {
142             let value = output.len() as u32;
143             match symbol_table.entry(name) {
144                 Entry::Vacant(v) => {
145                     v.insert(Some(value));
146                 }
147                 Entry::Occupied(o) => {
148                     if let Some(v) = o.get() {
149                         if *v != value {
150                             Err(lexer.error(String::from("syntax error")))?
151                         }
152                     }
153                 }
154             };
155         }
156         Token::At(name) => {
157             let mut value = symbol_table.entry(name).or_insert(None).unwrap_or(0);
158             lexer.get()?;
159             loop {
160                 let plus = match lexer.token {
161                     Some(Token::Plus) => true,
162                     Some(Token::Minus) => false,
163                     _ => break,
164                 };
165                 lexer.get()?;
166
167                 let operand = match lexer.token {
168                     Some(Token::At(ref name)) => if let Some(value) = symbol_table.get(name) {
169                         *value
170                     } else {
171                         symbol_table.insert(name.clone(), None);
172                         None
173                     }
174                     .unwrap_or(0),
175                     Some(Token::Integer(integer)) => integer
176                         .try_into()
177                         .map_err(|msg| lexer.error(format!("bad offset literal ({msg})")))?,
178                     _ => Err(lexer.error(String::from("expecting @label or integer literal")))?,
179                 };
180                 lexer.get()?;
181
182                 value = if plus {
183                     value.checked_add(operand)
184                 } else {
185                     value.checked_sub(operand)
186                 }
187                 .ok_or_else(|| lexer.error(String::from("overflow in offset arithmetic")))?;
188             }
189             output.extend_from_slice(&lexer.endian.to_bytes(value));
190         }
191         _ => (),
192     };
193     if lexer.token == Some(Token::Asterisk) {
194         lexer.get()?;
195         let Token::Integer(count) = lexer.take()? else {
196             Err(lexer.error(String::from("positive integer expected after '*'")))?
197         };
198         if count < 1 {
199             Err(lexer.error(String::from("positive integer expected after '*'")))?
200         };
201         let final_len = output.len();
202         for _ in 1..count {
203             output.extend_from_within(initial_len..final_len);
204         }
205     }
206     match lexer.token {
207         Some(Token::Semicolon) => {
208             lexer.get()?;
209         }
210         Some(Token::RParen) => (),
211         _ => Err(lexer.error(String::from("';' expected")))?,
212     }
213     Ok(true)
214 }
215
216 fn put_counted_items<T, const N: usize>(
217     lexer: &mut Lexer,
218     name: &str,
219     output: &mut Vec<u8>,
220     symbol_table: &mut HashMap<String, Option<u32>>,
221 ) -> Result<()>
222 where
223     T: Zero + TryFrom<usize>,
224     Endian: ToBytes<T, N>,
225 {
226     let old_size = output.len();
227     output.extend_from_slice(&lexer.endian.to_bytes(T::zero()));
228     if lexer.token != Some(Token::LParen) {
229         Err(lexer.error(format!("'(' expected after '{name}'")))?
230     }
231     lexer.get()?;
232     while lexer.token != Some(Token::RParen) {
233         parse_data_item(lexer, output, symbol_table)?;
234     }
235     lexer.get()?;
236     let delta = output.len() - old_size;
237     let Ok(delta): Result<T, _> = delta.try_into() else {
238         Err(lexer.error(format!("{delta} bytes is too much for '{name}'")))?
239     };
240     let dest = &mut output[old_size..old_size + N];
241     dest.copy_from_slice(&lexer.endian.to_bytes(delta));
242     Ok(())
243 }
244
245 fn put_integers<T, const N: usize>(
246     lexer: &mut Lexer,
247     name: &str,
248     output: &mut Vec<u8>,
249 ) -> Result<()>
250 where
251     T: Bounded + Display + TryFrom<i64> + Copy,
252     Endian: ToBytes<T, N>,
253 {
254     let mut n = 0;
255     while let Some(integer) = lexer.take_if(|t| match t {
256         Token::Integer(integer) => Some(*integer),
257         _ => None,
258     })? {
259         let Ok(integer) = integer.try_into() else {
260             Err(lexer.error(format!(
261                 "{integer} is not in the valid range [{},{}]",
262                 T::min_value(),
263                 T::max_value()
264             )))?
265         };
266         output.extend_from_slice(&lexer.endian.to_bytes(integer));
267         n += 1;
268     }
269     if n == 0 {
270         Err(lexer.error(format!("integer expected after '{name}'")))?
271     }
272     Ok(())
273 }
274
275 #[derive(PartialEq, Eq, Clone, Debug)]
276 enum Token {
277     Integer(i64),
278     Float(OrderedFloat<f64>),
279     PcSysmis,
280     String(String),
281     Semicolon,
282     Asterisk,
283     LParen,
284     RParen,
285     I8,
286     I16,
287     I64,
288     S(usize),
289     Count,
290     Count8,
291     Hex,
292     Label(String),
293     At(String),
294     Minus,
295     Plus,
296 }
297
298 struct Lexer<'a> {
299     iter: Peekable<Chars<'a>>,
300     token: Option<Token>,
301     input_file_name: Option<&'a str>,
302     line_number: usize,
303     endian: Endian,
304 }
305
306 impl<'a> Lexer<'a> {
307     fn new(input: &'a str, input_file_name: Option<&'a str>, endian: Endian) -> Result<Lexer<'a>> {
308         let mut lexer = Lexer {
309             iter: input.chars().peekable(),
310             token: None,
311             input_file_name,
312             line_number: 1,
313             endian,
314         };
315         lexer.token = lexer.next()?;
316         Ok(lexer)
317     }
318     fn error(&self, message: String) -> Error {
319         Error::new(self.input_file_name, Some(self.line_number), message)
320     }
321     fn take(&mut self) -> Result<Token> {
322         let Some(token) = self.token.take() else {
323             Err(self.error(String::from("unexpected end of input")))?
324         };
325         self.token = self.next()?;
326         Ok(token)
327     }
328     fn take_if<F, T>(&mut self, condition: F) -> Result<Option<T>>
329     where
330         F: FnOnce(&Token) -> Option<T>,
331     {
332         let Some(ref token) = self.token else {
333             return Ok(None);
334         };
335         match condition(token) {
336             Some(value) => {
337                 self.token = self.next()?;
338                 Ok(Some(value))
339             }
340             None => Ok(None),
341         }
342     }
343     fn get(&mut self) -> Result<Option<&Token>> {
344         if self.token.is_none() {
345             Err(self.error(String::from("unexpected end of input")))?
346         } else {
347             self.token = self.next()?;
348             Ok((&self.token).into())
349         }
350     }
351
352     fn next(&mut self) -> Result<Option<Token>> {
353         // Get the first character of the token, skipping past white space and
354         // comments.
355         let c = loop {
356             let Some(c) = self.iter.next() else {
357                 return Ok(None);
358             };
359             let c = if c == '#' {
360                 loop {
361                     match self.iter.next() {
362                         None => return Ok(None),
363                         Some('\n') => break,
364                         _ => (),
365                     }
366                 }
367                 '\n'
368             } else {
369                 c
370             };
371             if c == '\n' {
372                 self.line_number += 1
373             } else if !c.is_whitespace() && c != '<' && c != '>' {
374                 break c;
375             }
376         };
377
378         let token =
379             match c {
380                 c if c.is_ascii_digit() || c == '-' => {
381                     let mut s = String::from(c);
382                     while let Some(c) = self
383                         .iter
384                         .next_if(|&c| c.is_ascii_digit() || c.is_alphabetic() || c == '.')
385                     {
386                         s.push(c);
387                     }
388
389                     if s == "-" {
390                         Token::Minus
391                     } else if !s.contains('.') {
392                         Token::Integer(s.parse().map_err(|msg| {
393                             self.error(format!("bad integer literal '{s}' ({msg})"))
394                         })?)
395                     } else {
396                         Token::Float(s.parse().map_err(|msg| {
397                             self.error(format!("bad float literal '{s}' ({msg})"))
398                         })?)
399                     }
400                 }
401                 '"' => {
402                     let mut s = String::new();
403                     loop {
404                         match self.iter.next() {
405                             None => Err(self.error(String::from("end-of-file inside string")))?,
406                             Some('\n') => Err(self.error(String::from("new-line inside string")))?,
407                             Some('"') => break,
408                             Some(c) => s.push(c),
409                         }
410                     }
411                     Token::String(s)
412                 }
413                 ';' => Token::Semicolon,
414                 '*' => Token::Asterisk,
415                 '+' => Token::Plus,
416                 '(' => Token::LParen,
417                 ')' => Token::RParen,
418                 c if c.is_alphabetic() || c == '@' || c == '_' => {
419                     let mut s = String::from(c);
420                     while let Some(c) = self.iter.next_if(|&c| {
421                         c.is_ascii_digit() || c.is_alphabetic() || c == '.' || c == '_'
422                     }) {
423                         s.push(c);
424                     }
425                     if self.iter.next_if_eq(&':').is_some() {
426                         Token::Label(s)
427                     } else if s.starts_with('@') {
428                         Token::At(s)
429                     } else if let Some(count) = s.strip_prefix('s') {
430                         Token::S(count.parse().map_err(|msg| {
431                             self.error(format!("bad counted string '{s}' ({msg})"))
432                         })?)
433                     } else {
434                         match &s[..] {
435                             "i8" => Token::I8,
436                             "i16" => Token::I16,
437                             "i64" => Token::I64,
438                             "SYSMIS" => Token::Float(OrderedFloat(-f64::MAX)),
439                             "PCSYSMIS" => Token::PcSysmis,
440                             "LOWEST" => Token::Float((-f64::MAX).next_after(0.0).into()),
441                             "HIGHEST" => Token::Float(f64::MAX.into()),
442                             "ENDIAN" => {
443                                 Token::Integer(if self.endian == Endian::Big { 1 } else { 2 })
444                             }
445                             "COUNT" => Token::Count,
446                             "COUNT8" => Token::Count8,
447                             "hex" => Token::Hex,
448                             _ => Err(self.error(format!("invalid token '{s}'")))?,
449                         }
450                     }
451                 }
452                 _ => Err(self.error(format!("invalid input byte '{c}'")))?,
453             };
454         Ok(Some(token))
455     }
456 }
457
458 #[cfg(test)]
459 mod test {
460     use crate::endian::Endian;
461     use crate::sack::sack;
462     use anyhow::Result;
463     use hexplay::HexView;
464
465     #[test]
466     fn basic_sack() -> Result<()> {
467         let input = r#"
468 "$FL2"; s60 "$(#) SPSS DATA FILE PSPP synthetic test file";
469 2; # Layout code
470 28; # Nominal case size
471 0; # Not compressed
472 0; # Not weighted
473 1; # 1 case.
474 100.0; # Bias.
475 "01 Jan 11"; "20:53:52";
476 "PSPP synthetic test file: "; i8 244; i8 245; i8 246; i8 248; s34 "";
477 i8 0 *3;
478 "#;
479         let output = sack(input, None, Endian::Big)?;
480         HexView::new(&output).print()?;
481         Ok(())
482     }
483 }