41012c467e4eb40a880333cfbeb28a9d2c141b8e
[pspp] / rust / src / sack.rs
1 use anyhow::{anyhow, Result};
2 use float_next_after::NextAfter;
3 use num::{Bounded, Zero};
4 use ordered_float::OrderedFloat;
5 use std::{
6     collections::{hash_map::Entry, HashMap},
7     fmt::Display,
8     iter::{repeat, Peekable},
9     str::Chars,
10 };
11
12 use crate::endian::{Endian, ToBytes};
13
14 pub fn sack(input: &str, endian: Endian) -> Result<Vec<u8>> {
15     let mut symbol_table = HashMap::new();
16     let output = _sack(input, endian, &mut symbol_table)?;
17     let output = if !symbol_table.is_empty() {
18         for (k, v) in symbol_table.iter() {
19             if v.is_none() {
20                 return Err(anyhow!("label {k} used but never defined"));
21             }
22         }
23         _sack(input, endian, &mut symbol_table)?
24     } else {
25         output
26     };
27     Ok(output)
28 }
29
30 fn _sack(
31     input: &str,
32     endian: Endian,
33     symbol_table: &mut HashMap<String, Option<u32>>,
34 ) -> Result<Vec<u8>> {
35     let mut lexer = Lexer::new(input, endian)?;
36     let mut output = Vec::new();
37     while parse_data_item(&mut lexer, &mut output, symbol_table)? {}
38     Ok(output)
39 }
40
41 fn parse_data_item(
42     lexer: &mut Lexer,
43     output: &mut Vec<u8>,
44     symbol_table: &mut HashMap<String, Option<u32>>,
45 ) -> Result<bool> {
46     if lexer.token.is_none() {
47         return Ok(false);
48     };
49
50     let initial_len = output.len();
51     match lexer.take()? {
52         Token::Integer(integer) => output.extend_from_slice(&lexer.endian.to_bytes(integer)),
53         Token::Float(float) => output.extend_from_slice(&lexer.endian.to_bytes(float.0)),
54         Token::PcSysmis => {
55             output.extend_from_slice(&[0xf5, 0x1e, 0x26, 0x02, 0x8a, 0x8c, 0xed, 0xff])
56         }
57         Token::I8 => put_integers::<u8, 1>(lexer, "i8", output)?,
58         Token::I16 => put_integers::<u16, 2>(lexer, "i16", output)?,
59         Token::I64 => put_integers::<i64, 8>(lexer, "i64", output)?,
60         Token::String(string) => output.extend_from_slice(string.as_bytes()),
61         Token::S(size) => {
62             let Some(Token::String(ref string)) = lexer.token else {
63                 return Err(anyhow!("string expected after 's{size}'"));
64             };
65             let len = string.len();
66             if len > size {
67                 return Err(anyhow!(
68                     "{len}-byte string is longer than pad length {size}"
69                 ));
70             }
71             output.extend_from_slice(string.as_bytes());
72             output.extend(repeat(b' ').take(size - len));
73             lexer.get()?;
74         }
75         Token::LParen => {
76             while lexer.token != Some(Token::RParen) {
77                 parse_data_item(lexer, output, symbol_table)?;
78             }
79             lexer.get()?;
80         }
81         Token::Count => put_counted_items::<u32, 4>(lexer, "COUNT", output, symbol_table)?,
82         Token::Count8 => put_counted_items::<u8, 1>(lexer, "COUNT8", output, symbol_table)?,
83         Token::Hex => {
84             let Some(Token::String(ref string)) = lexer.token else {
85                 return Err(anyhow!("string expected after 'hex'"));
86             };
87             let mut i = string.chars();
88             loop {
89                 let Some(c0) = i.next() else { return Ok(true) };
90                 let Some(c1) = i.next() else {
91                     return Err(anyhow!("hex string has odd number of characters"));
92                 };
93                 let (Some(digit0), Some(digit1)) = (c0.to_digit(16), c1.to_digit(16)) else {
94                     return Err(anyhow!("invalid digit in hex string"));
95                 };
96                 let byte = digit0 * 16 + digit1;
97                 output.push(byte as u8);
98             }
99         }
100         Token::Label(name) => {
101             let value = output.len() as u32;
102             match symbol_table.entry(name) {
103                 Entry::Vacant(v) => {
104                     v.insert(Some(value));
105                 }
106                 Entry::Occupied(o) => {
107                     if let Some(v) = o.get() {
108                         if *v != value {
109                             return Err(anyhow!("syntax error"));
110                         }
111                     }
112                 }
113             };
114         }
115         Token::At(name) => {
116             let mut value = symbol_table.entry(name).or_insert(None).unwrap_or(0);
117             lexer.get()?;
118             loop {
119                 let plus = match lexer.token {
120                     Some(Token::Plus) => true,
121                     Some(Token::Minus) => false,
122                     _ => break,
123                 };
124                 lexer.get()?;
125
126                 let operand = match lexer.token {
127                     Some(Token::At(ref name)) => if let Some(value) = symbol_table.get(name) {
128                         *value
129                     } else {
130                         symbol_table.insert(name.clone(), None);
131                         None
132                     }
133                     .unwrap_or(0),
134                     Some(Token::Integer(integer)) => integer
135                         .try_into()
136                         .map_err(|msg| anyhow!("bad offset literal ({msg})"))?,
137                     _ => return Err(anyhow!("expecting @label or integer literal")),
138                 };
139                 lexer.get()?;
140
141                 value = if plus {
142                     value.checked_add(operand)
143                 } else {
144                     value.checked_sub(operand)
145                 }
146                 .ok_or_else(|| anyhow!("overflow in offset arithmetic"))?;
147             }
148             output.extend_from_slice(&lexer.endian.to_bytes(value));
149         }
150         _ => (),
151     };
152     if lexer.token == Some(Token::Asterisk) {
153         lexer.get()?;
154         let Token::Integer(count) = lexer.take()? else {
155             return Err(anyhow!("positive integer expected after '*'"));
156         };
157         if count < 1 {
158             return Err(anyhow!("positive integer expected after '*'"));
159         };
160         let final_len = output.len();
161         for _ in 1..count {
162             output.extend_from_within(initial_len..final_len);
163         }
164     }
165     match lexer.token {
166         Some(Token::Semicolon) => {
167             lexer.get()?;
168         }
169         Some(Token::RParen) => (),
170         _ => return Err(anyhow!("';' expected")),
171     }
172     Ok(true)
173 }
174
175 fn put_counted_items<T, const N: usize>(
176     lexer: &mut Lexer,
177     name: &str,
178     output: &mut Vec<u8>,
179     symbol_table: &mut HashMap<String, Option<u32>>,
180 ) -> Result<()>
181 where
182     T: Zero + TryFrom<usize>,
183     Endian: ToBytes<T, N>,
184 {
185     let old_size = output.len();
186     output.extend_from_slice(&lexer.endian.to_bytes(T::zero()));
187     if lexer.token != Some(Token::LParen) {
188         return Err(anyhow!("'(' expected after '{name}'"));
189     }
190     lexer.get()?;
191     while lexer.token != Some(Token::RParen) {
192         parse_data_item(lexer, output, symbol_table)?;
193     }
194     lexer.get()?;
195     let delta = output.len() - old_size;
196     let Ok(delta): Result<T, _> = delta.try_into() else {
197         return Err(anyhow!("{delta} bytes is too much for '{name}'"));
198     };
199     let dest = &mut output[old_size..old_size + N];
200     dest.copy_from_slice(&lexer.endian.to_bytes(delta));
201     Ok(())
202 }
203
204 fn put_integers<T, const N: usize>(
205     lexer: &mut Lexer,
206     name: &str,
207     output: &mut Vec<u8>,
208 ) -> Result<()>
209 where
210     T: Bounded + Display + TryFrom<i64> + Copy,
211     Endian: ToBytes<T, N>,
212 {
213     let mut n = 0;
214     while let Some(integer) = lexer.take_if(|t| match t {
215         Token::Integer(integer) => Some(*integer),
216         _ => None,
217     })? {
218         let Ok(integer) = integer.try_into() else {
219             return Err(anyhow!(
220                 "{integer} is not in the valid range [{},{}]",
221                 T::min_value(),
222                 T::max_value()
223             ));
224         };
225         output.extend_from_slice(&lexer.endian.to_bytes(integer));
226         n += 1;
227     }
228     if n == 0 {
229         return Err(anyhow!("integer expected after '{name}'"));
230     }
231     Ok(())
232 }
233
234 #[derive(PartialEq, Eq, Clone)]
235 enum Token {
236     Integer(i64),
237     Float(OrderedFloat<f64>),
238     PcSysmis,
239     String(String),
240     Semicolon,
241     Asterisk,
242     LParen,
243     RParen,
244     I8,
245     I16,
246     I64,
247     S(usize),
248     Count,
249     Count8,
250     Hex,
251     Label(String),
252     At(String),
253     Minus,
254     Plus,
255 }
256
257 struct Lexer<'a> {
258     iter: Peekable<Chars<'a>>,
259     token: Option<Token>,
260     line_number: usize,
261     endian: Endian,
262 }
263
264 impl<'a> Lexer<'a> {
265     fn new(input: &'a str, endian: Endian) -> Result<Lexer<'a>> {
266         let mut lexer = Lexer {
267             iter: input.chars().peekable(),
268             token: None,
269             line_number: 1,
270             endian,
271         };
272         lexer.next()?;
273         Ok(lexer)
274     }
275     fn take(&mut self) -> Result<Token> {
276         let Some(token) = self.token.take() else {
277             return Err(anyhow!("unexpected end of input"));
278         };
279         self.token = self.next()?;
280         Ok(token)
281     }
282     fn take_if<F, T>(&mut self, condition: F) -> Result<Option<T>>
283     where
284         F: FnOnce(&Token) -> Option<T>,
285     {
286         let Some(ref token) = self.token else {
287             return Ok(None);
288         };
289         match condition(token) {
290             Some(value) => {
291                 self.token = self.next()?;
292                 Ok(Some(value))
293             }
294             None => Ok(None),
295         }
296     }
297     fn get(&mut self) -> Result<Option<&Token>> {
298         if self.token.is_none() {
299             Err(anyhow!("unexpected end of input"))
300         } else {
301             self.token = self.next()?;
302             Ok((&self.token).into())
303         }
304     }
305
306     fn next(&mut self) -> Result<Option<Token>> {
307         // Get the first character of the token, skipping past white space and
308         // comments.
309         let c = loop {
310             let Some(c) = self.iter.next() else {
311                 return Ok(None);
312             };
313             let c = if c == '#' {
314                 loop {
315                     match self.iter.next() {
316                         None => return Ok(None),
317                         Some('\n') => break,
318                         _ => (),
319                     }
320                 }
321                 '\n'
322             } else {
323                 c
324             };
325             if c == '\n' {
326                 self.line_number += 1
327             } else if !c.is_whitespace() && c != '<' && c != '>' {
328                 break c;
329             }
330         };
331
332         let token = match c {
333             c if c.is_ascii_digit() || c == '-' => {
334                 let mut s = String::from(c);
335                 while let Some(c) = self
336                     .iter
337                     .next_if(|&c| c.is_ascii_digit() || c.is_alphabetic() || c == '.')
338                 {
339                     s.push(c);
340                 }
341
342                 if s == "-" {
343                     Token::Minus
344                 } else if !s.contains('.') {
345                     Token::Integer(
346                         s.parse()
347                             .map_err(|msg| anyhow!("bad integer literal '{s}' ({msg})"))?,
348                     )
349                 } else {
350                     Token::Float(
351                         s.parse()
352                             .map_err(|msg| anyhow!("bad float literal '{s}' ({msg})"))?,
353                     )
354                 }
355             }
356             '"' => {
357                 let mut s = String::from(c);
358                 loop {
359                     match self.iter.next() {
360                         None => return Err(anyhow!("end-of-file inside string")),
361                         Some('\n') => return Err(anyhow!("new-line inside string")),
362                         Some('"') => break,
363                         Some(c) => s.push(c),
364                     }
365                 }
366                 Token::String(s)
367             }
368             ';' => Token::Semicolon,
369             '*' => Token::Asterisk,
370             '+' => Token::Plus,
371             '(' => Token::LParen,
372             ')' => Token::RParen,
373             c if c.is_alphabetic() || c == '@' || c == '_' => {
374                 let mut s = String::from(c);
375                 while let Some(c) = self
376                     .iter
377                     .next_if(|&c| c.is_ascii_digit() || c.is_alphabetic() || c == '.' || c == '_')
378                 {
379                     s.push(c);
380                 }
381                 if self.iter.next_if_eq(&':').is_some() {
382                     Token::Label(s)
383                 } else if s.starts_with('@') {
384                     Token::At(s)
385                 } else if let Some(count) = s.strip_prefix('s') {
386                     Token::S(
387                         count
388                             .parse()
389                             .map_err(|msg| anyhow!("bad counted string '{s}' ({msg})"))?,
390                     )
391                 } else {
392                     match &s[..] {
393                         "i8" => Token::I8,
394                         "i16" => Token::I16,
395                         "i64" => Token::I64,
396                         "SYSMIS" => Token::Float(OrderedFloat(-f64::MAX)),
397                         "PCSYSMIS" => Token::PcSysmis,
398                         "LOWEST" => Token::Float((-f64::MAX).next_after(0.0).into()),
399                         "HIGHEST" => Token::Float(f64::MAX.into()),
400                         "ENDIAN" => Token::Integer(if self.endian == Endian::Big { 1 } else { 2 }),
401                         "COUNT" => Token::Count,
402                         "COUNT8" => Token::Count8,
403                         "hex" => Token::Hex,
404                         _ => return Err(anyhow!("invalid token '{s}'")),
405                     }
406                 }
407             }
408             _ => return Err(anyhow!("invalid input byte '{c}'")),
409         };
410         Ok(Some(token))
411     }
412 }