Introduce Arg class.
authorBen Pfaff <blp@cs.stanford.edu>
Tue, 14 Dec 2021 05:21:37 +0000 (21:21 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Tue, 14 Dec 2021 05:21:37 +0000 (21:21 -0800)
src/language/expressions/generate.py

index ef7e2ab7e748ce6e163e5cc51be14ea3e7ee5fa4..6cb6a9f257370e5be8ac4132798a445e0009968a 100644 (file)
@@ -272,7 +272,7 @@ def parse_input():
         while not match(')'):
             arg = parse_arg()
             op['ARGS'] += [arg]
-            if 'IDX' in arg:
+            if arg.idx is not None:
                 if match(')'):
                     break
                 sys.stderr.write('array must be last argument\n')
@@ -282,16 +282,14 @@ def parse_input():
                 break
 
         for arg in op['ARGS']:
-            if 'CONDITION' in arg:
-                any_arg = '|'.join([a['NAME'] for a in op['ARGS']])
-                arg['CONDITION'] = re.sub(r'\b(%s)\b' % any_arg, r'arg_\1',
-                                          arg['CONDITION'])
+            if arg.condition is not None:
+                any_arg = '|'.join([a.name for a in op['ARGS']])
+                arg.condition = re.sub(r'\b(%s)\b' % any_arg, r'arg_\1', arg.condition)
 
         opname = 'OP_' + op['NAME']
         opname = opname.replace('.', '_')
         if op['CATEGORY'] == 'function':
-            print(op)
-            mangle = ''.join([a['TYPE']['MANGLE'] for a in op['ARGS']])
+            mangle = ''.join([a.type_['MANGLE'] for a in op['ARGS']])
             op['MANGLE'] = mangle
             opname += '_' + mangle
         op['OPNAME'] = opname
@@ -301,10 +299,10 @@ def parse_input():
             if aa is None:
                 sys.stderr.write("can't have minimum valid count without array arg\n")
                 sys.exit(1)
-            if aa['TYPE']['NAME'] != 'number':
+            if aa.type_['NAME'] != 'number':
                 sys.stderr.write('minimum valid count allowed only with double array\n')
                 sys.exit(1)
-            if aa['TIMES'] != 1:
+            if aa.times != 1:
                 sys.stderr.write("can't have minimu valid count if array has multiplication factor\n")
                 sys.exit(1)
 
@@ -333,12 +331,12 @@ def parse_input():
 
         if op['RETURNS']['NAME'] == 'string' and not op['ABSORB_MISS']:
             for arg in op['ARGS']:
-                if arg['TYPE']['NAME'] in ['number', 'boolean']:
+                if arg.type_['NAME'] in ['number', 'boolean']:
                     sys.stderr.write("'%s' returns string and has double or bool "
                                      "argument, but is not marked ABSORB_MISS\n"
                                      % op['NAME'])
                     sys.exit(1)
-                if 'CONDITION' in arg:
+                if arg.condition is not None:
                     sys.stderr.write("'%s' returns string but has argument with condition\n")
                     sys.exit(1)
 
@@ -379,7 +377,6 @@ def get_token():
     if toktype == 'eof':
         return
 
-    print('%s %s' % (line, toktype))
     m = re.match(r'([a-zA-Z_][a-zA-Z_.0-9]*)(.*)$', line)
     if m:
         token, line = m.groups()
@@ -444,11 +441,8 @@ def accumulate_balanced(end, swallow_end=True):
     s = ""
     nest = 0
     global line
-    print("line='%s'" % line)
     while True:
-        print(type(line))
         for idx, c in enumerate(line):
-            print('nest=%s %s end=%s' % (nest, c, end))
             if c in end and nest == 0:
                 line = line[idx:]
                 if swallow_end:
@@ -473,7 +467,6 @@ def get_line():
     global line_number
     line = in_file.readline()
     line_number += 1
-    print("%s\n" % line_number)
     if line == '':
         line = None
     else:
@@ -523,41 +516,50 @@ def force_match(tok):
         sys.stderr.write("parse error at `%s' expecting `%s'\n" % (token, tok))
         sys.exit(1)
 
+class Arg:
+    def __init__(self, name, type_, idx, times, condition):
+        self.name = name
+        self.type_ = type_
+        self.idx = idx
+        self.times = times
+        self.condition = condition
+
 def parse_arg():
     """Parses and returns a function argument."""
-    arg = {}
-    arg['TYPE'] = parse_type()
-    if arg['TYPE'] is None:
-        arg['TYPE'] = types['number']
+    type_ = parse_type()
+    if type_ is None:
+        type_ = types['number']
 
     if toktype != 'id':
         sys.stderr.write("argument name expected at `%s'\n" % token)
         sys.exit(1)
-    arg['NAME'] = token
+    name = token
 
     lookahead()
     global line
-    print("line[0]=%s" % line[0])
+
+    idx = None
+    times = 1
+
     if line[0] in "[,)":
         get_token()
-        print('token=%s toktype=%s' % (token, toktype))
         if match('['):
-            if arg['TYPE']['NAME'] not in ('number', 'string'):
+            if type_['NAME'] not in ('number', 'string'):
                 sys.stderr.write('only double and string arrays supported\n')
                 sys.exit(1)
-            arg['IDX'] = force('id')
+            idx = force('id')
             if match('*'):
-                arg['TIMES'] = force('int')
-                if arg['TIMES'] != 2:
+                times = force('int')
+                if times != 2:
                     sys.stderr.write('multiplication factor must be two\n')
                     sys.exit(1)
-            else:
-                arg['TIMES'] = 1
             force_match(']')
+        condition = None
     else:
-        arg['CONDITION'] = arg['NAME'] + ' ' + accumulate_balanced(',)', swallow_end=False)
+        condition = name + ' ' + accumulate_balanced(',)', swallow_end=False)
         get_token()
-    return arg
+
+    return Arg(name, type_, idx, times, condition)
 
 def print_header():
     """Prints the output file header."""
@@ -590,11 +592,11 @@ def generate_evaluate_h():
 
         args = []
         for arg in op['ARGS']:
-            if 'IDX' not in arg:
-                args += [c_type(arg['TYPE']) + arg['NAME']]
+            if arg.idx is None:
+                args += [c_type(arg.type_) + arg.name]
             else:
-                args += [c_type(arg['TYPE']) + arg['NAME'] + '[]']
-                args += ['size_t %s' % arg['IDX']]
+                args += [c_type(arg.type_) + arg.name + '[]']
+                args += ['size_t %s' % arg.idx]
         for aux in op['AUX']:
             args += [c_type(aux['TYPE']) + aux['NAME']]
         if not args:
@@ -622,12 +624,11 @@ def generate_evaluate_inc():
         decls = []
         args = []
         for arg in op['ARGS']:
-            name = arg['NAME']
-            type_ = arg['TYPE']
+            type_ = arg.type_
             ctype = c_type(type_)
-            args += ['arg_%s' % name]
-            if 'IDX' not in arg:
-                decl = '%sarg_%s' % (ctype, name)
+            args += ['arg_%s' % arg.name]
+            if arg.idx is None:
+                decl = '%sarg_%s' % (ctype, arg.name)
                 if type_['ROLE'] == 'any':
                     decls = ['%s = *--%s' % (decl, type_['STACK'])] + decls
                 elif type_['ROLE'] == 'leaf':
@@ -635,14 +636,14 @@ def generate_evaluate_inc():
                 else:
                     assert False
             else:
-                idx = arg['IDX']
+                idx = arg.idx
                 stack = type_['STACK']
-                decls = ['%s*arg_%s = %s -= arg_%s' % (ctype, name, stack, idx)] + decls
+                decls = ['%s*arg_%s = %s -= arg_%s' % (ctype, arg.name, stack, idx)] + decls
                 decls = ['size_t arg_%s = op++->integer' % idx] + decls
 
                 idx = 'arg_%s' % idx
-                if arg['TIMES'] != 1:
-                    idx += ' / %s' % arg['TIMES']
+                if arg.times != 1:
+                    idx += ' / %s' % arg.times
                 args += [idx]
         for aux in op['AUX']:
             type_ = aux['TYPE']
@@ -658,7 +659,6 @@ def generate_evaluate_inc():
         if sysmis_cond is not None:
             decls += [sysmis_cond]
 
-        print(args)
         result = 'eval_%s (%s)' % (op['OPNAME'], ', '.join(args))
 
         stack = op['RETURNS']['STACK']
@@ -738,20 +738,19 @@ def generate_optimize_inc():
         decls = []
         arg_idx = 0
         for arg in op['ARGS']:
-            name = arg['NAME']
-            type_ = arg['TYPE']
+            name = arg.name
+            type_ = arg.type_
             ctype = c_type(type_)
-            if not 'IDX' in arg:
+            if arg.idx is None:
                 func = "get_%s_arg" % type_['ATOM']
                 decls += ["%sarg_%s = %s (node, %s)" % (ctype, name, func, arg_idx)]
             else:
-                idx = arg['IDX']
-                decl = "size_t arg_%s = node->n_args" % idx
+                decl = "size_t arg_%s = node->n_args" % arg.idx
                 if arg_idx > 0:
                     decl += " - %s" % arg_idx
                 decls += [decl]
 
-                decls += ["%s*arg_%s = get_%s_args  (node, %s, arg_%s, e)" % (ctype, name, type_['ATOM'], arg_idx, idx)]
+                decls += ["%s*arg_%s = get_%s_args  (node, %s, arg_%s, e)" % (ctype, name, type_['ATOM'], arg_idx, arg.idx)]
             arg_idx += 1
 
         sysmis_cond = make_sysmis_decl (op, "node->min_valid")
@@ -760,11 +759,11 @@ def generate_optimize_inc():
 
         args = []
         for arg in op['ARGS']:
-            args += ["arg_%s" % arg['NAME']]
-            if 'IDX' in arg:
-                idx = 'arg_%s' % arg['IDX']
-                if arg['TIMES'] != 1:
-                    idx += " / %s" % arg['TIMES']
+            args += ["arg_%s" % arg.name]
+            if arg.idx is not None:
+                idx = 'arg_%s' % arg.idx
+                if arg.times != 1:
+                    idx += " / %s" % arg.times
                 args += [idx]
 
         for aux in op['AUX']:
@@ -816,21 +815,21 @@ def generate_parse_inc():
             args = []
             opt_args = []
             for arg in op['ARGS']:
-                if 'IDX' not in arg:
-                    args += [arg['TYPE']['HUMAN_NAME']]
+                if arg.idx is None:
+                    args += [arg.type_['HUMAN_NAME']]
 
             array = array_arg(op)
             if array is not None:
                 if op['MIN_VALID'] == 0:
                     array_args = []
-                    for i in range(array['TIMES']):
-                        array_args += [array['TYPE']['HUMAN_NAME']]
+                    for i in range(array.times):
+                        array_args += [array.type_['HUMAN_NAME']]
                     args += array_args
                     opt_args = array_args
                 else:
                     for i in range(op['MIN_VALID']):
-                        args += [array['TYPE']['HUMAN_NAME']]
-                    opt_args += [array['TYPE']['HUMAN_NAME']]
+                        args += [array.type_['HUMAN_NAME']]
+                    opt_args += [array.type_['HUMAN_NAME']]
             human = "%s(%s" % (op['NAME'], ', '.join(args))
             if opt_args:
                 human += '[, %s]...' % ', '.join(opt_args)
@@ -862,12 +861,12 @@ def generate_parse_inc():
 
         members += ['%s' % len(op['ARGS'])]
 
-        arg_types = ["OP_%s" % arg['TYPE']['NAME'] for arg in op['ARGS']]
+        arg_types = ["OP_%s" % arg.type_['NAME'] for arg in op['ARGS']]
         members += ['{%s}' % ', '.join(arg_types)]
 
         members += ['%s' % op['MIN_VALID']]
 
-        members += ['%s' % (array_arg(op)['TIMES'] if array_arg(op) else 0)]
+        members += ['%s' % (array_arg(op).times if array_arg(op) else 0)]
 
         out_file.write('{%s},\n' % ', '.join(members))
 \f
@@ -886,23 +885,23 @@ def make_sysmis_decl(op, min_valid_src):
     sysmis_cond = []
     if not op['ABSORB_MISS']:
         for arg in op['ARGS']:
-            arg_name = 'arg_%s' % arg['NAME']
-            if 'IDX' not in arg:
-                if arg['TYPE']['NAME'] in ['number', 'boolean']:
+            arg_name = 'arg_%s' % arg.name
+            if arg.idx is None:
+                if arg.type_['NAME'] in ['number', 'boolean']:
                     sysmis_cond += ["!is_valid (%s)" % arg_name]
-            elif arg['TYPE']['NAME'] == 'number':
+            elif arg.type_['NAME'] == 'number':
                 a = arg_name
-                n = 'arg_%s' % arg['IDX']
+                n = 'arg_%s' % arg.idx
                 sysmis_cond += ['count_valid (%s, %s) < %s' % (a, n, n)]
     elif op['MIN_VALID'] > 0:
         args = op['ARGS']
         arg = args[-1]
-        a = 'arg_%s' % arg['NAME']
-        n = 'arg_%s' % arg['IDX']
+        a = 'arg_%s' % arg.name
+        n = 'arg_%s' % arg.idx
         sysmis_cond += ["count_valid (%s, %s) < %s" % (a, n, min_valid_src)]
     for arg in op['ARGS']:
-        if 'CONDITION' in arg:
-            sysmis_cond += ['!(%s)' % arg['CONDITION']]
+        if arg.condition is not None:
+            sysmis_cond += ['!(%s)' % arg.condition]
     if sysmis_cond:
         return 'bool force_sysmis = %s' % ' || '.join(sysmis_cond)
     return None
@@ -914,7 +913,7 @@ def array_arg(op):
     if not args:
         return None
     last_arg = args[-1]
-    if 'IDX' in last_arg:
+    if last_arg.idx is not None:
         return last_arg
     return None