Introduce Op type.
authorBen Pfaff <blp@cs.stanford.edu>
Tue, 14 Dec 2021 06:09:40 +0000 (22:09 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Tue, 14 Dec 2021 06:11:25 +0000 (22:11 -0800)
src/language/expressions/generate.py

index 4d87ce8f0637fc4275d1ba51c4f743dc6f313197..39fe2632ac5662f273165660387b371a4d4b33e9 100644 (file)
@@ -84,9 +84,9 @@ init_type has 2 required arguments:
           "any": Usable as operands and function arguments, and
           function and operator results.
 
-          "leaf": Usable as operands and function arguments, but
-          not function arguments or results.  (Thus, they appear
-          only in leaf nodes in the parse type.)
+          "leaf": Usable as operands and function arguments, but not
+          results.  (Thus, they appear only in leaf nodes in the parse
+          tree.)
 
           "fixed": Not allowed either as an operand or argument
           type or a result type.  Used only as auxiliary data.
@@ -173,6 +173,32 @@ def c_type(type_):
 \f
 # Input parsing.
 
+class Op:
+    def __init__(self,
+                 name, category,
+                 returns, args, aux,
+                 expression, block,
+                 min_valid,
+                 optimizable, unimplemented, extension, perm_only, absorb_miss, no_abbrev,
+                 opname, mangle):
+        self.name = name
+        self.category = category
+        self.returns = returns
+        self.args = args
+        self.aux = aux
+        self.expression = expression
+        self.block = block
+        self.min_valid = min_valid
+        self.optimizable = optimizable
+        self.unimplemented = unimplemented
+        self.extension = extension
+        self.perm_only = perm_only
+        self.absorb_miss = absorb_miss
+        self.no_abbrev = no_abbrev
+        self.opname = opname
+        if mangle is not None:
+            self.mangle = mangle
+
 def parse_input():
     """Parses the entire input.
 
@@ -196,25 +222,23 @@ def parse_input():
     opers = []
 
     while toktype != 'eof':
-        op = {
-            'OPTIMIZABLE': True,
-            'UNIMPLEMENTED': False,
-            'EXTENSION': False,
-            'PERM_ONLY': False,
-            'ABSORB_MISS': False,
-            'NO_ABBREV': False,
-        }
+        optimizable = True
+        unimplemented = False
+        extension = False
+        perm_only = False
+        absorb_miss = False
+        no_abbrev = False
         while True:
             if match('extension'):
-                op['EXTENSION'] = True
+                extension = True
             elif match('no_opt'):
-                op['OPTIMIZABLE'] = False
+                optimizable = False
             elif match('absorb_miss'):
-                op['ABSORB_MISS'] = True
+                absorb_miss = True
             elif match('perm_only'):
-                op['PERM_ONLY'] = True
+                perm_only = True
             elif match('no_abbrev'):
-                op['NO_ABBREV'] = True
+                no_abbrev = True
             else:
                 break
 
@@ -222,21 +246,20 @@ def parse_input():
         if return_type is None:
             return_type = types['number']
         if return_type.name not in ['number', 'string', 'boolean']:
-            sys.stderr.write('%s is not a valid return type\n' % return_type['NAME'])
+            sys.stderr.write('%s is not a valid return type\n' % return_type.name)
             sys.exit(1)
-        op['RETURNS'] = return_type
 
-        op['CATEGORY'] = token
-        if op['CATEGORY'] not in ['operator', 'function']:
+        category = token
+        if category not in ['operator', 'function']:
             sys.stderr.write("'operator' or 'function' expected at '%s'" % token)
             sys.exit(1)
         get_token()
 
         name = force('id')
-        if op['CATEGORY'] == 'function' and '_' in name:
+        if category == 'function' and '_' in name:
             sys.stderr.write("function name '%s' may not contain underscore\n" % name)
             sys.exit(1)
-        elif op['CATEGORY'] == 'operator' and '.' in name:
+        elif category == 'operator' and '.' in name:
             sys.stderr.write("operator name '%s' may not contain period\n" % name)
             sys.exit(1)
 
@@ -244,17 +267,16 @@ def parse_input():
         if m:
             prefix, suffix = m.groups()
             name = prefix
-            op['MIN_VALID'] = int(suffix)
-            op['ABSORB_MISS'] = True
+            min_valid = int(suffix)
+            absorb_miss = True
         else:
-            op['MIN_VALID'] = 0
-        op['NAME'] = name
+            min_valid = 0
 
         force_match('(')
-        op['ARGS'] = []
+        args = []
         while not match(')'):
             arg = parse_arg()
-            op['ARGS'] += [arg]
+            args += [arg]
             if arg.idx is not None:
                 if match(')'):
                     break
@@ -264,32 +286,20 @@ def parse_input():
                 force_match(')')
                 break
 
-        for arg in op['ARGS']:
+        for arg in args:
             if arg.condition is not None:
-                any_arg = '|'.join([a.name for a in op['ARGS']])
+                any_arg = '|'.join([a.name for a in args])
                 arg.condition = re.sub(r'\b(%s)\b' % any_arg, r'arg_\1', arg.condition)
 
-        opname = 'OP_' + op['NAME']
+        opname = 'OP_' + name
         opname = opname.replace('.', '_')
-        if op['CATEGORY'] == 'function':
-            mangle = ''.join([a.type_.mangle for a in op['ARGS']])
-            op['MANGLE'] = mangle
+        if category == 'function':
+            mangle = ''.join([a.type_.mangle for a in args])
             opname += '_' + mangle
-        op['OPNAME'] = opname
-
-        if op['MIN_VALID'] > 0:
-            aa = array_arg(op)
-            if aa is None:
-                sys.stderr.write("can't have minimum valid count without array arg\n")
-                sys.exit(1)
-            if aa.type_.name != 'number':
-                sys.stderr.write('minimum valid count allowed only with double array\n')
-                sys.exit(1)
-            if aa.times != 1:
-                sys.stderr.write("can't have minimu valid count if array has multiplication factor\n")
-                sys.exit(1)
+        else:
+            mangle = None
 
-        op['AUX'] = []
+        aux = []
         while toktype == 'id':
             type_ = parse_type()
             if type_ is None:
@@ -298,37 +308,39 @@ def parse_input():
             if type_.role not in ['leaf', 'fixed']:
                 sys.stderr.write("'%s' is not allowed as auxiliary data\n" % type_.name)
                 sys.exit(1)
-            name = force('id')
-            op['AUX'] += [{'TYPE': type_, 'NAME': name}]
+            aux_name = force('id')
+            aux += [{'TYPE': type_, 'NAME': aux_name}]
             force_match(';')
 
-        if op['OPTIMIZABLE']:
-            if op['NAME'].startswith('RV.'):
+        if optimizable:
+            if name.startswith('RV.'):
                 sys.stderr.write("random variate functions must be marked 'no_opt'\n")
                 sys.exit(1)
             for key in ['CASE', 'CASE_IDX']:
-                if key in op['AUX']:
+                if key in aux:
                     sys.stderr.write("operators with %s aux data must be marked 'no_opt'\n" % key)
                     sys.exit(1)
 
-        if op['RETURNS'].name == 'string' and not op['ABSORB_MISS']:
-            for arg in op['ARGS']:
+        if return_type.name == 'string' and not absorb_miss:
+            for arg in args:
                 if arg.type_.name in ['number', 'boolean']:
                     sys.stderr.write("'%s' returns string and has double or bool "
                                      "argument, but is not marked ABSORB_MISS\n"
-                                     % op['NAME'])
+                                     % name)
                     sys.exit(1)
                 if arg.condition is not None:
                     sys.stderr.write("'%s' returns string but has argument with condition\n")
                     sys.exit(1)
 
         if toktype == 'block':
-            op['BLOCK'] = force('block')
+            block = force('block')
+            expression = None
         elif toktype == 'expression':
             if token == 'unimplemented':
-                op['UNIMPLEMENTED'] = True
+                unimplemented = True
             else:
-                op['EXPRESSION'] = token
+                expression = token
+            block = None
             get_token()
         else:
             sys.stderr.write("block or expression expected\n")
@@ -337,15 +349,37 @@ def parse_input():
         if opname in ops:
             sys.stderr.write("duplicate operation name %s\n" % opname)
             sys.exit(1)
+
+        op = Op(name, category,
+                return_type, args, aux,
+                expression, block,
+                min_valid,
+                optimizable, unimplemented, extension, perm_only, absorb_miss,
+                no_abbrev,
+                opname, mangle)
+
+        if min_valid > 0:
+            aa = array_arg(op)
+            if aa is None:
+                sys.stderr.write("can't have minimum valid count without array arg\n")
+                sys.exit(1)
+            if aa.type_.name != 'number':
+                sys.stderr.write('minimum valid count allowed only with double array\n')
+                sys.exit(1)
+            if aa.times != 1:
+                sys.stderr.write("can't have minimu valid count if array has multiplication factor\n")
+                sys.exit(1)
+
         ops[opname] = op
-        if op['CATEGORY'] == 'function':
+        if category == 'function':
             funcs += [opname]
         else:
             opers += [opname]
+
     in_file.close()
 
-    funcs = sorted(funcs, key=lambda name: (ops[name]['NAME'], ops[name]['OPNAME']))
-    opers = sorted(opers, key=lambda name: ops[name]['NAME'])
+    funcs = sorted(funcs, key=lambda name: (ops[name].name, ops[name].opname))
+    opers = sorted(opers, key=lambda name: ops[name].name)
     order = funcs + opers
 
 def get_token():
@@ -569,27 +603,27 @@ def generate_evaluate_h():
 
     for opname in order:
         op = ops[opname]
-        if op['UNIMPLEMENTED']:
+        if op.unimplemented:
             continue
 
         args = []
-        for arg in op['ARGS']:
+        for arg in op.args:
             if arg.idx is None:
                 args += [c_type(arg.type_) + arg.name]
             else:
                 args += [c_type(arg.type_) + arg.name + '[]']
                 args += ['size_t %s' % arg.idx]
-        for aux in op['AUX']:
+        for aux in op.aux:
             args += [c_type(aux['TYPE']) + aux['NAME']]
         if not args:
             args += ['void']
 
-        if 'BLOCK' in op:
-            statements = op['BLOCK'] + '\n'
+        if op.block:
+            statements = op.block + '\n'
         else:
-            statements = "  return %s;\n" % op['EXPRESSION']
+            statements = "  return %s;\n" % op.expression
 
-        out_file.write("static inline %s\n" % c_type (op['RETURNS']))
+        out_file.write("static inline %s\n" % c_type (op.returns))
         out_file.write("eval_%s (%s)\n" % (opname, ', '.join(args)))
         out_file.write("{\n")
         out_file.write(statements)
@@ -598,14 +632,14 @@ def generate_evaluate_h():
 def generate_evaluate_inc():
     for opname in order:
         op = ops[opname]
-        if op['UNIMPLEMENTED']:
+        if op.unimplemented:
             out_file.write("case %s:\n" % opname)
             out_file.write("  NOT_REACHED ();\n\n")
             continue
 
         decls = []
         args = []
-        for arg in op['ARGS']:
+        for arg in op.args:
             type_ = arg.type_
             ctype = c_type(type_)
             args += ['arg_%s' % arg.name]
@@ -626,7 +660,7 @@ def generate_evaluate_inc():
                 if arg.times != 1:
                     idx += ' / %s' % arg.times
                 args += [idx]
-        for aux in op['AUX']:
+        for aux in op.aux:
             type_ = aux['TYPE']
             name = aux['NAME']
             if type_.role == 'leaf':
@@ -640,9 +674,9 @@ def generate_evaluate_inc():
         if sysmis_cond is not None:
             decls += [sysmis_cond]
 
-        result = 'eval_%s (%s)' % (op['OPNAME'], ', '.join(args))
+        result = 'eval_%s (%s)' % (op.opname, ', '.join(args))
 
-        stack = op['RETURNS'].stack
+        stack = op.returns.stack
 
         out_file.write("case %s:\n" % opname)
         if decls:
@@ -650,7 +684,7 @@ def generate_evaluate_inc():
             for decl in decls:
                 out_file.write("    %s;\n" % decl)
             if sysmis_cond is not None:
-                miss_ret = op['RETURNS'].missing_value
+                miss_ret = op.returns.missing_value
                 out_file.write("    *%s++ = force_sysmis ? %s : %s;\n" % (stack, miss_ret, result))
             else:
                 out_file.write("    *%s++ = %s;\n" % (stack, result))
@@ -711,14 +745,14 @@ def generate_optimize_inc():
     for opname in order:
         op = ops[opname]
 
-        if not op['OPTIMIZABLE'] or op['UNIMPLEMENTED']:
+        if not op.optimizable or op.unimplemented:
             out_file.write("case %s:\n" % opname)
             out_file.write("  NOT_REACHED ();\n\n")
             continue
 
         decls = []
         arg_idx = 0
-        for arg in op['ARGS']:
+        for arg in op.args:
             name = arg.name
             type_ = arg.type_
             ctype = c_type(type_)
@@ -739,7 +773,7 @@ def generate_optimize_inc():
             decls += [sysmis_cond]
 
         args = []
-        for arg in op['ARGS']:
+        for arg in op.args:
             args += ["arg_%s" % arg.name]
             if arg.idx is not None:
                 idx = 'arg_%s' % arg.idx
@@ -747,7 +781,7 @@ def generate_optimize_inc():
                     idx += " / %s" % arg.times
                 args += [idx]
 
-        for aux in op['AUX']:
+        for aux in op.aux:
             type_ = aux['TYPE']
             if type_.role == 'leaf':
                 func = "get_%s_arg" % type_.atom
@@ -758,14 +792,14 @@ def generate_optimize_inc():
             else:
                 assert False
 
-        result = "eval_%s (%s)" % (op['OPNAME'], ', '.join(args))
+        result = "eval_%s (%s)" % (op.opname, ', '.join(args))
         if decls and sysmis_cond is not None:
-            miss_ret = op['RETURNS'].missing_value
-            decls += ['%sresult = force_sysmis ? %s : %s' % (c_type(op['RETURNS']), miss_ret, result)]
+            miss_ret = op.returns.missing_value
+            decls += ['%sresult = force_sysmis ? %s : %s' % (c_type(op.returns), miss_ret, result)]
             result = 'result'
 
         out_file.write("case %s:\n" % opname)
-        alloc_func = "expr_allocate_%s" % op['RETURNS'].name
+        alloc_func = "expr_allocate_%s" % op.returns.name
         if decls:
             out_file.write("  {\n")
             for decl in decls:
@@ -789,28 +823,28 @@ def generate_parse_inc():
         op = ops[opname]
 
         members = []
-        members += ['"%s"' % op['NAME']]
+        members += ['"%s"' % op.name]
 
-        if op['CATEGORY'] == 'function':
+        if op.category == 'function':
             args = []
             opt_args = []
-            for arg in op['ARGS']:
+            for arg in op.args:
                 if arg.idx is None:
                     args += [arg.type_.human_name]
 
             array = array_arg(op)
             if array is not None:
-                if op['MIN_VALID'] == 0:
+                if op.min_valid == 0:
                     array_args = []
                     for i in range(array.times):
                         array_args += [array.type_.human_name]
                     args += array_args
                     opt_args = array_args
                 else:
-                    for i in range(op['MIN_VALID']):
+                    for i in range(op.min_valid):
                         args += [array.type_.human_name]
                     opt_args += [array.type_.human_name]
-            human = "%s(%s" % (op['NAME'], ', '.join(args))
+            human = "%s(%s" % (op.name, ', '.join(args))
             if opt_args:
                 human += '[, %s]...' % ', '.join(opt_args)
             human += ')'
@@ -819,32 +853,32 @@ def generate_parse_inc():
             members += ['NULL']
 
         flags = []
-        if op['ABSORB_MISS']:
+        if op.absorb_miss:
             flags += ['OPF_ABSORB_MISS']
         if array_arg(op):
             flags += ['OPF_ARRAY_OPERAND']
-        if op['MIN_VALID'] > 0:
+        if op.min_valid > 0:
             flags += ['OPF_MIN_VALID']
-        if not op['OPTIMIZABLE']:
+        if not op.optimizable:
             flags += ['OPF_NONOPTIMIZABLE']
-        if op['EXTENSION']:
+        if op.extension:
             flags += ['OPF_EXTENSION']
-        if op['UNIMPLEMENTED']:
+        if op.unimplemented:
             flags += ['OPF_UNIMPLEMENTED']
-        if op['PERM_ONLY']:
+        if op.perm_only:
             flags += ['OPF_PERM_ONLY']
-        if op['NO_ABBREV']:
+        if op.no_abbrev:
             flags += ['OPF_NO_ABBREV']
         members += [' | '.join(flags) if flags else '0']
 
-        members += ['OP_%s' % op['RETURNS'].name]
+        members += ['OP_%s' % op.returns.name]
 
-        members += ['%s' % len(op['ARGS'])]
+        members += ['%s' % len(op.args)]
 
-        arg_types = ["OP_%s" % arg.type_.name for arg in op['ARGS']]
+        arg_types = ["OP_%s" % arg.type_.name for arg in op.args]
         members += ['{%s}' % ', '.join(arg_types)]
 
-        members += ['%s' % op['MIN_VALID']]
+        members += ['%s' % op.min_valid]
 
         members += ['%s' % (array_arg(op).times if array_arg(op) else 0)]
 
@@ -863,8 +897,8 @@ def make_sysmis_decl(op, min_valid_src):
 
     """
     sysmis_cond = []
-    if not op['ABSORB_MISS']:
-        for arg in op['ARGS']:
+    if not op.absorb_miss:
+        for arg in op.args:
             arg_name = 'arg_%s' % arg.name
             if arg.idx is None:
                 if arg.type_.name in ['number', 'boolean']:
@@ -873,13 +907,13 @@ def make_sysmis_decl(op, min_valid_src):
                 a = arg_name
                 n = 'arg_%s' % arg.idx
                 sysmis_cond += ['count_valid (%s, %s) < %s' % (a, n, n)]
-    elif op['MIN_VALID'] > 0:
-        args = op['ARGS']
+    elif op.min_valid > 0:
+        args = op.args
         arg = args[-1]
         a = 'arg_%s' % arg.name
         n = 'arg_%s' % arg.idx
         sysmis_cond += ["count_valid (%s, %s) < %s" % (a, n, min_valid_src)]
-    for arg in op['ARGS']:
+    for arg in op.args:
         if arg.condition is not None:
             sysmis_cond += ['!(%s)' % arg.condition]
     if sysmis_cond:
@@ -889,7 +923,7 @@ def make_sysmis_decl(op, min_valid_src):
 def array_arg(op):
     """If 'op' has an array argument, returns it.  Otherwise, returns
     None."""
-    args = op['ARGS']
+    args = op.args
     if not args:
         return None
     last_arg = args[-1]