From 00f158b3ed84054bc0a4d193637f082e761f5cf1 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 24 Mar 2018 17:14:27 -0500
Subject: [PATCH] Started adding the reduction interface

---
 loopy/kernel/creation.py           |  69 ++++++++++++--
 loopy/kernel/function_interface.py | 142 +++++++++++++++++++++++------
 loopy/kernel/reduction_callable.py |  85 +++++++++++++++++
 loopy/library/reduction.py         |   7 ++
 loopy/symbolic.py                  |  49 +++++-----
 5 files changed, 293 insertions(+), 59 deletions(-)
 create mode 100644 loopy/kernel/reduction_callable.py

diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 124984ea3..5a6423220 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -1832,7 +1832,7 @@ def apply_single_writer_depencency_heuristic(kernel, warn_if_used=True):
 # }}}
 
 
-# {{{ lookup functions
+# {{{ scope functions
 
 class FunctionScoper(IdentityMapper):
     """
@@ -1880,6 +1880,29 @@ class FunctionScoper(IdentityMapper):
         # This is an unknown function as of yet, not modifying it.
         return IdentityMapper.map_call(self, expr)
 
+    def map_reduction(self, expr):
+        from pymbolic.primitives import Variable
+        from loopy.symbolic import ScopedFunction
+
+        mapped_inames = [self.rec(Variable(iname)) for iname in expr.inames]
+
+        new_inames = []
+        for iname, new_sym_iname in zip(expr.inames, mapped_inames):
+            if not isinstance(new_sym_iname, Variable):
+                from loopy.diagnostic import LoopyError
+                raise LoopyError("%s did not map iname '%s' to a variable"
+                        % (type(self).__name__, iname))
+
+            new_inames.append(new_sym_iname.name)
+
+        from loopy.symbolic import Reduction
+
+        return Reduction(
+                ScopedFunction(expr.operation.name),
+                tuple(new_inames),
+                self.rec(expr.expr),
+                allow_simultaneous=expr.allow_simultaneous)
+
 
 class ScopedFunctionCollector(CombineMapper):
     """ This mapper would collect all the instances of :class:`ScopedFunction`
@@ -1890,7 +1913,44 @@ class ScopedFunctionCollector(CombineMapper):
         return reduce(operator.or_, values, frozenset())
 
     def map_scoped_function(self, expr):
-        return frozenset([expr.name])
+        from loopy.kernel.function_interface import CallableOnScalar
+        return frozenset([(expr.name, CallableOnScalar(expr.name))])
+
+    def map_reduction(self, expr):
+        from loopy.kernel.reduction_callable import CallableReduction
+        from loopy.symbolic import Reduction
+
+        callable_reduction = CallableReduction(expr.operation.name)
+
+        # sanity checks
+
+        if isinstance(expr.expr, tuple):
+            num_args = len(expr.expr)
+        else:
+            num_args = 1
+
+        if num_args != callable_reduction.operation.arg_count:
+            raise RuntimeError("invalid invocation of "
+                    "reduction operation '%s': expected %d arguments, "
+                    "got %d instead" % (expr.function.name,
+                                        callable_reduction.operation.arg_count,
+                                        len(expr.parameters)))
+
+        if callable_reduction.operation.arg_count > 1:
+            from pymbolic.primitives import Call
+
+            if not isinstance(expr, (tuple, Reduction, Call)):
+                raise LoopyError("reduction argument must be one of "
+                                 "a tuple, reduction, or call; "
+                                 "got '%s'" % type(expr).__name__)
+        else:
+            if isinstance(expr, tuple):
+                raise LoopyError("got a tuple argument to a scalar reduction")
+            elif isinstance(expr, Reduction) and callable_reduction.is_tuple_typed:
+                raise LoopyError("got a tuple typed argument to a scalar reduction")
+
+        return frozenset([(expr.operation.name,
+            callable_reduction)])
 
     def map_constant(self, expr):
         return frozenset()
@@ -1921,10 +1981,7 @@ def scope_functions(kernel):
                     type(insn))
 
     # Need to combine the scoped functions into a dict
-    from loopy.kernel.function_interface import CallableOnScalar
-    scoped_function_dict = dict((func, CallableOnScalar(func)) for func in
-            scoped_functions)
-
+    scoped_function_dict = dict(scoped_functions)
     return kernel.copy(instructions=new_insns, scoped_functions=scoped_function_dict)
 
 # }}}
diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py
index bbd6e43cc..a87c1670a 100644
--- a/loopy/kernel/function_interface.py
+++ b/loopy/kernel/function_interface.py
@@ -134,8 +134,7 @@ class InKernelCallable(ImmutableRecord):
 
     """
 
-    def __init__(self, name, subkernel=None, arg_id_to_dtype=None,
-            arg_id_to_descr=None, name_in_target=None):
+    def __init__(self, name, arg_id_to_dtype=None, arg_id_to_descr=None):
 
         # {{{ sanity checks
 
@@ -144,14 +143,9 @@ class InKernelCallable(ImmutableRecord):
 
         # }}}
 
-        if name_in_target is not None and subkernel is not None:
-            subkernel = subkernel.copy(name=name_in_target)
-
         super(InKernelCallable, self).__init__(name=name,
-                subkernel=subkernel,
                 arg_id_to_dtype=arg_id_to_dtype,
-                arg_id_to_descr=arg_id_to_descr,
-                name_in_target=name_in_target)
+                arg_id_to_descr=arg_id_to_descr)
 
     def with_types(self, arg_id_to_dtype, target):
         """
@@ -233,20 +227,29 @@ class InKernelCallable(ImmutableRecord):
 
     # }}}
 
-    def __eq__(self, other):
-        return (self.name == other.name
-                and self.arg_id_to_descr == other.arg_id_to_descr
-                and self.arg_id_to_dtype == other.arg_id_to_dtype
-                and self.subkernel == other.subkernel)
+# }}}
 
-    def __hash__(self):
-        return hash((self.name, self.subkernel, self.name_in_target))
 
+# {{{ callables on scalar
 
-# }}}
+class CallableOnScalar(InKernelCallable):
 
+    fields = set(["name", "arg_id_to_dtype", "arg_id_to_descr", "name_in_target"])
+    init_arg_names = ("name", "arg_id_to_dtype", "arg_id_to_descr",
+            "name_in_target")
 
-class CallableOnScalar(InKernelCallable):
+    def __init__(self, name, arg_id_to_dtype=None,
+            arg_id_to_descr=None, name_in_target=None):
+
+        super(InKernelCallable, self).__init__(name=name,
+                arg_id_to_dtype=arg_id_to_dtype,
+                arg_id_to_descr=arg_id_to_descr)
+
+        self.name_in_target = name_in_target
+
+    def __getinitargs__(self):
+        return (self.name, self.arg_id_to_dtype, self.arg_id_to_descr,
+                self.name_in_target)
 
     def with_types(self, arg_id_to_dtype, target):
         if self.arg_id_to_dtype is not None:
@@ -384,9 +387,32 @@ class CallableOnScalar(InKernelCallable):
 
     # }}}
 
+# }}}
+
+
+# {{{ callable kernel
 
 class CallableKernel(InKernelCallable):
 
+    fields = set(["name", "subkernel", "arg_id_to_dtype", "arg_id_to_descr",
+        "name_in_target"])
+    init_arg_names = ("name", "subkernel", "arg_id_to_dtype", "arg_id_to_descr",
+            "name_in_target")
+
+    def __init__(self, name, subkernel, arg_id_to_dtype=None,
+            arg_id_to_descr=None, name_in_target=None):
+
+        super(InKernelCallable, self).__init__(name=name,
+                arg_id_to_dtype=arg_id_to_dtype,
+                arg_id_to_descr=arg_id_to_descr)
+
+        self.name_in_target = name_in_target
+        self.subkernel = subkernel
+
+    def __getinitargs__(self):
+        return (self.name, self.subkernel, self.arg_id_to_dtype,
+                self.arg_id_to_descr, self.name_in_target)
+
     def with_types(self, arg_id_to_dtype, target):
 
         kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel)
@@ -475,12 +501,9 @@ class CallableKernel(InKernelCallable):
     def generate_preambles(self, target):
         """ This would generate the target specific preamble.
         """
+        # Transfer the preambel of the subkernel over here
         raise NotImplementedError()
 
-    def emit_call(self, expression_to_code_mapper, expression, target):
-
-        raise NotImplementedError("emit_call only works on scalar operations")
-
     def emit_call_insn(self, insn, target, expression_to_code_mapper):
 
         assert self.is_ready_for_code_gen()
@@ -524,14 +547,77 @@ class CallableKernel(InKernelCallable):
 
     # }}}
 
-    def __eq__(self, other):
-        return (self.name == other.name
-                and self.arg_id_to_descr == other.arg_id_to_descr
-                and self.arg_id_to_dtype == other.arg_id_to_dtype
-                and self.subkernel == other.subkernel)
+# }}}
+
+
+
+
+
+
+class ReductionCallable(InKernelCallable):
+
+    fields = set(["name", "operation", "arg_id_to_dtype", "arg_id_to_descr"])
+    init_arg_names = ("name", "operation", "arg_id_to_dtype", "arg_id_to_descr")
+
+    def __init__(self, name, operation, arg_id_to_dtype=None,
+            arg_id_to_descr=None, name_in_target=None):
+
+        super(InKernelCallable, self).__init__(name=name,
+                arg_id_to_dtype=arg_id_to_dtype,
+                arg_id_to_descr=arg_id_to_descr)
+
+        self.operation = operation
+
+    def with_types(self, arg_id_to_dtype, target):
+        if self.arg_id_to_dtype is not None:
+
+            # specializing an already specialized function.
+
+            for id, dtype in arg_id_to_dtype.items():
+                # only checking for the ones which have been provided
+                if self.arg_id_to_dtype[id] != arg_id_to_dtype[id]:
+                    raise LoopyError("Overwriting a specialized"
+                            " function is illegal--maybe start with new instance of"
+                            " CallableScalar?")
+
+        if self.name in target.get_device_ast_builder().function_identifiers():
+            new_in_knl_callable = target.get_device_ast_builder().with_types(
+                    self, arg_id_to_dtype)
+            if new_in_knl_callable is None:
+                new_in_knl_callable = self.copy()
+            return new_in_knl_callable
+
+        # did not find a scalar function and function prototype does not
+        # even have  subkernel registered => no match found
+        raise LoopyError("Function %s not present within"
+                " the %s namespace" % (self.name, target))
+
+    def with_descrs(self, arg_id_to_descr):
+
+        # This is a scalar call
+        # need to assert that the name is in funtion indentifiers
+        arg_id_to_descr[-1] = ValueArgDescriptor()
+        return self.copy(arg_id_to_descr=arg_id_to_descr)
+
+    def with_iname_tag_usage(self, unusable, concurrent_shape):
+
+        raise NotImplementedError()
+
+    def is_ready_for_code_gen(self):
+
+        return (self.arg_id_to_dtype is not None and
+                self.arg_id_to_descr is not None and
+                self.name_in_target is not None)
+
+
+
+
+
+
+
+
+
 
-    def __hash__(self):
-        return hash((self.name, self.subkernel, self.name_in_target))
 
 
 # {{{ new pymbolic calls to scoped functions
diff --git a/loopy/kernel/reduction_callable.py b/loopy/kernel/reduction_callable.py
new file mode 100644
index 000000000..1682f7160
--- /dev/null
+++ b/loopy/kernel/reduction_callable.py
@@ -0,0 +1,85 @@
+# Note: this file is just for convenience purposes. This would go back into
+# kernel/function_interface.py.
+# keeping it over here until everythin starts working.
+
+
+from __future__ import division, absolute_import
+
+from loopy.diagnostic import LoopyError
+
+from loopy.kernel.function_interface import (InKernelCallable,
+        ValueArgDescriptor)
+
+
+class CallableReduction(InKernelCallable):
+
+    fields = set(["operation", "arg_id_to_dtype", "arg_id_to_descr"])
+    init_arg_names = ("operation", "arg_id_to_dtype", "arg_id_to_descr")
+
+    def __init__(self, operation, arg_id_to_dtype=None,
+            arg_id_to_descr=None, name_in_target=None):
+
+        if isinstance(operation, str):
+            from loopy.library.reduction import parse_reduction_op
+            operation = parse_reduction_op(operation)
+
+        from loopy.library.reduction import ReductionOperation
+        assert isinstance(operation, ReductionOperation)
+
+        self.operation = operation
+
+        super(InKernelCallable, self).__init__(name="",
+                arg_id_to_dtype=arg_id_to_dtype,
+                arg_id_to_descr=arg_id_to_descr)
+
+    def __getinitargs__(self):
+        return (self.operation, self.arg_id_to_dtype,
+                self.arg_id_to_descr)
+
+    @property
+    def is_tuple_typed(self):
+        return self.operation.arg_count > 1
+
+    def with_types(self, arg_id_to_dtype, target):
+        if self.arg_id_to_dtype is not None:
+
+            # specializing an already specialized function.
+
+            for id, dtype in arg_id_to_dtype.items():
+                # only checking for the ones which have been provided
+                if self.arg_id_to_dtype[id] != arg_id_to_dtype[id]:
+                    raise LoopyError("Overwriting a specialized"
+                            " function is illegal--maybe start with new instance of"
+                            " CallableScalar?")
+
+        if self.name in target.get_device_ast_builder().function_identifiers():
+            new_in_knl_callable = target.get_device_ast_builder().with_types(
+                    self, arg_id_to_dtype)
+            if new_in_knl_callable is None:
+                new_in_knl_callable = self.copy()
+            return new_in_knl_callable
+
+        # did not find a scalar function and function prototype does not
+        # even have  subkernel registered => no match found
+        raise LoopyError("Function %s not present within"
+                " the %s namespace" % (self.name, target))
+
+    def with_descrs(self, arg_id_to_descr):
+
+        # This is a scalar call
+        # need to assert that the name is in funtion indentifiers
+        arg_id_to_descr[-1] = ValueArgDescriptor()
+        return self.copy(arg_id_to_descr=arg_id_to_descr)
+
+    def with_iname_tag_usage(self, unusable, concurrent_shape):
+
+        raise NotImplementedError()
+
+    def is_ready_for_code_gen(self):
+
+        return (self.arg_id_to_dtype is not None and
+                self.arg_id_to_descr is not None and
+                self.name_in_target is not None)
+
+
+# vim: foldmethod=marker
diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index 0e5a093b7..5daa1528a 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -422,6 +422,13 @@ def parse_reduction_op(name):
 # }}}
 
 
+def reduction_function_identifiers():
+    """ Return a :class:`set` of the type of the reduction identifiers that can be
+    encountered in a kernel.
+    """
+    return set(op for op in _REDUCTION_OPS)
+
+
 def reduction_function_mangler(kernel, func_id, arg_dtypes):
     if isinstance(func_id, ArgExtOp):
         from loopy.target.opencl import CTarget
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index bdfe57982..e8e39a24f 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -537,9 +537,11 @@ class Reduction(p.Expression):
     """Represents a reduction operation on :attr:`exprs`
     across :attr:`inames`.
 
-    .. attribute:: operation
+    ..attribute:: operation
 
-        an instance of :class:`loopy.library.reduction.ReductionOperation`
+        an instance of :class:`pymbolic.primitives.Variable` which indicates
+        the reduction callable that the reduction would point to in the dict
+        `kernel.scoped_functions`
 
     .. attribute:: inames
 
@@ -563,6 +565,8 @@ class Reduction(p.Expression):
     init_arg_names = ("operation", "inames", "expr", "allow_simultaneous")
 
     def __init__(self, operation, inames, expr, allow_simultaneous=False):
+        assert isinstance(operation, p.Variable)
+
         if isinstance(inames, str):
             inames = tuple(iname.strip() for iname in inames.split(","))
 
@@ -580,6 +584,8 @@ class Reduction(p.Expression):
 
         inames = tuple(strip_var(iname) for iname in inames)
 
+        """
+        # Removed by KK. In order to move to the new interface
         if isinstance(operation, str):
             from loopy.library.reduction import parse_reduction_op
             operation = parse_reduction_op(operation)
@@ -602,6 +608,7 @@ class Reduction(p.Expression):
                 raise LoopyError("got a tuple argument to a scalar reduction")
             elif isinstance(expr, Reduction) and expr.is_tuple_typed:
                 raise LoopyError("got a tuple typed argument to a scalar reduction")
+        """
 
         self.operation = operation
         self.inames = inames
@@ -622,10 +629,12 @@ class Reduction(p.Expression):
 
     def stringifier(self):
         return StringifyMapper
-
+    """
+    # Removed by KK. In order to move to the new interface
     @property
     def is_tuple_typed(self):
         return self.operation.arg_count > 1
+    """
 
     @property
     @memoize_method
@@ -1139,6 +1148,8 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
     def _parse_reduction(self, operation, inames, red_exprs,
             allow_simultaneous=False):
+        assert isinstance(operation, str)
+        operation = p.Variable(operation)
         if isinstance(inames, p.Variable):
             inames = (inames,)
 
@@ -1161,7 +1172,7 @@ class FunctionToPrimitiveMapper(IdentityMapper):
                 allow_simultaneous=allow_simultaneous)
 
     def map_call(self, expr):
-        from loopy.library.reduction import parse_reduction_op
+        from loopy.library.reduction import reduction_function_identifiers
 
         if not isinstance(expr.function, p.Variable):
             return IdentityMapper.map_call(self, expr)
@@ -1181,18 +1192,22 @@ class FunctionToPrimitiveMapper(IdentityMapper):
             else:
                 raise TypeError("cse takes two arguments")
 
-        elif name in ["reduce", "simul_reduce"]:
-
+        elif name in set(["reduce, simul_reduce"]):
             if len(expr.parameters) >= 3:
                 operation, inames = expr.parameters[:2]
                 red_exprs = expr.parameters[2:]
 
-                operation = parse_reduction_op(str(operation))
-                return self._parse_reduction(operation, inames,
+                return self._parse_reduction(str(operation), inames,
                         tuple(self.rec(red_expr) for red_expr in red_exprs),
                         allow_simultaneous=(name == "simul_reduce"))
             else:
+
                 raise TypeError("invalid 'reduce' calling sequence")
+        elif name in reduction_function_identifiers():
+            # KK -- maybe add a check for the arg count?
+            inames = expr.parameters[0]
+            red_exprs = tuple(self.rec(param) for param in expr.parameters[1:])
+            return self._parse_reduction(name, inames, red_exprs)
 
         elif name == "if":
             if len(expr.parameters) == 3:
@@ -1203,23 +1218,7 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
         else:
             # see if 'name' is an existing reduction op
-
-            operation = parse_reduction_op(name)
-            if operation:
-                # arg_count counts arguments but not inames
-                if len(expr.parameters) != 1 + operation.arg_count:
-                    raise RuntimeError("invalid invocation of "
-                            "reduction operation '%s': expected %d arguments, "
-                            "got %d instead" % (expr.function.name,
-                                                1 + operation.arg_count,
-                                                len(expr.parameters)))
-
-                inames = expr.parameters[0]
-                red_exprs = tuple(self.rec(param) for param in expr.parameters[1:])
-                return self._parse_reduction(operation, inames, red_exprs)
-
-            else:
-                return IdentityMapper.map_call(self, expr)
+            return IdentityMapper.map_call(self, expr)
 
     def map_call_with_kwargs(self, expr):
         for par in expr.kw_parameters.values():
-- 
GitLab