From 1e28c40a3cdc8b44ba2b05631e6942cfd79444cf Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 6 Aug 2018 06:18:06 -0500
Subject: [PATCH] passes test_callables

---
 loopy/transform/callable.py             | 96 ++++++++++++++++---------
 loopy/transform/pack_and_unpack_args.py | 36 +++++++++-
 test/test_callables.py                  | 77 ++++++++++----------
 test/testlib.py                         | 13 ++--
 4 files changed, 144 insertions(+), 78 deletions(-)

diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py
index 3549d1b75..f73fb9003 100644
--- a/loopy/transform/callable.py
+++ b/loopy/transform/callable.py
@@ -28,7 +28,6 @@ import islpy as isl
 from pymbolic.primitives import CallWithKwargs
 
 from loopy.kernel import LoopKernel
-from loopy.kernel.function_interface import CallableKernel
 from pytools import ImmutableRecord
 from loopy.diagnostic import LoopyError
 from loopy.kernel.instruction import (CallInstruction, MultiAssignmentBase,
@@ -36,13 +35,13 @@ from loopy.kernel.instruction import (CallInstruction, MultiAssignmentBase,
 from loopy.symbolic import IdentityMapper, SubstitutionMapper, CombineMapper
 from loopy.isl_helpers import simplify_via_aff
 from loopy.kernel.function_interface import (get_kw_pos_association,
-        change_names_of_pymbolic_calls)
+        change_names_of_pymbolic_calls, CallableKernel, ScalarCallable)
 from loopy.program import Program, ResolvedFunctionMarker
 
 __doc__ = """
 .. currentmodule:: loopy
 
-.. autofunction:: register_function_resolver
+.. autofunction:: register_function_id_to_in_knl_callable_mapper
 
 .. autofunction:: register_callable_kernel
 """
@@ -170,31 +169,38 @@ def register_callable_kernel(program, callee_kernel):
         arg.is_output_only])
     expected_num_parameters = len(callee_kernel.args) - expected_num_assignees
     for in_knl_callable in program.program_callables_info.values():
-        caller_kernel = in_knl_callable.subkernel
-        for insn in caller_kernel.instructions:
-            if isinstance(insn, CallInstruction) and (
-                    insn.expression.function.name == callee_kernel.name):
-                if isinstance(insn.expression, CallWithKwargs):
-                    kw_parameters = insn.expression.kw_parameters
+        if isinstance(in_knl_callable, CallableKernel):
+            caller_kernel = in_knl_callable.subkernel
+            for insn in caller_kernel.instructions:
+                if isinstance(insn, CallInstruction) and (
+                        insn.expression.function.name == callee_kernel.name):
+                    if isinstance(insn.expression, CallWithKwargs):
+                        kw_parameters = insn.expression.kw_parameters
+                    else:
+                        kw_parameters = {}
+                    if len(insn.assignees) != expected_num_assignees:
+                        raise LoopyError("The number of arguments with 'out' "
+                                "direction " "in callee kernel %s and the number "
+                                "of assignees in " "instruction %s do not "
+                                "match." % (
+                                    callee_kernel.name, insn.id))
+                    if len(insn.expression.parameters+tuple(
+                            kw_parameters.values())) != expected_num_parameters:
+                        raise LoopyError("The number of expected arguments "
+                                "for the callee kernel %s and the number of "
+                                "parameters in instruction %s do not match."
+                                % (callee_kernel.name, insn.id))
+
+                elif isinstance(insn, (MultiAssignmentBase, CInstruction,
+                        _DataObliviousInstruction)):
+                    pass
                 else:
-                    kw_parameters = {}
-                if len(insn.assignees) != expected_num_assignees:
-                    raise LoopyError("The number of arguments with 'out' direction "
-                            "in callee kernel %s and the number of assignees in "
-                            "instruction %s do not match." % (
-                                callee_kernel.name, insn.id))
-                if len(insn.expression.parameters+tuple(
-                        kw_parameters.values())) != expected_num_parameters:
-                    raise LoopyError("The number of expected arguments "
-                            "for the callee kernel %s and the number of parameters "
-                            "in instruction %s do not match." % (
-                                callee_kernel.name, insn.id))
-
-            elif isinstance(insn, (MultiAssignmentBase, CInstruction,
-                    _DataObliviousInstruction)):
-                pass
-            else:
-                raise NotImplementedError("unknown instruction %s" % type(insn))
+                    raise NotImplementedError("unknown instruction %s" % type(insn))
+        elif isinstance(in_knl_callable, ScalarCallable):
+            pass
+        else:
+            raise NotImplementedError("Unknown callable type %s." %
+                    type(in_knl_callable).__name__)
 
     # }}}
 
@@ -537,12 +543,11 @@ def _inline_single_callable_kernel(caller_kernel, function_name,
                 history_of_identifier = program_callables_info.history[
                         insn.expression.function.name]
 
-                from loopy.kernel.function_interface import CallableKernel
                 if function_name in history_of_identifier:
                     in_knl_callable = program_callables_info[
                             insn.expression.function.name]
                     assert isinstance(in_knl_callable, CallableKernel)
-                    new_caller_kernel = _inline_call_instruction(
+                    caller_kernel = _inline_call_instruction(
                             caller_kernel, in_knl_callable.subkernel, insn)
                     program_callables_info = (
                             program_callables_info.with_deleted_callable(
@@ -557,7 +562,7 @@ def _inline_single_callable_kernel(caller_kernel, function_name,
                     "Unknown instruction type %s"
                     % type(insn).__name__)
 
-    return new_caller_kernel, program_callables_info
+    return caller_kernel, program_callables_info
 
 
 # FIXME This should take a 'within' parameter to be able to only inline
@@ -581,7 +586,7 @@ def inline_callable_kernel(program, function_name):
             caller_kernel, program_callables_info = (
                     _inline_single_callable_kernel(caller_kernel,
                         function_name,
-                        program.program_callables_info))
+                        program_callables_info))
             edited_callable_kernels[func_id] = in_knl_callable.copy(
                     subkernel=caller_kernel)
 
@@ -642,7 +647,8 @@ class DimChanger(IdentityMapper):
         return expr.aggregate.index(tuple(new_indices))
 
 
-def _match_caller_callee_argument_dimension(caller_knl, callee_function_name):
+def _match_caller_callee_argument_dimension_for_single_kernel(
+        caller_knl, callee_function_name):
     """
     Returns a copy of *caller_knl* with the instance of
     :class:`loopy.kernel.function_interface.CallableKernel` addressed by
@@ -722,6 +728,32 @@ def _match_caller_callee_argument_dimension(caller_knl, callee_function_name):
     return change_names_of_pymbolic_calls(caller_knl,
             pymbolic_calls_to_new_callables)
 
+
+def _match_caller_callee_argument_dimension_(program, *args, **kwargs):
+    """Match caller/callee argument dimensions in every callable kernel of
+    *program*, forwarding *args*/*kwargs* to the single-kernel helper."""
+    assert isinstance(program, Program)
+
+    new_resolved_functions = {}
+    for func_id, in_knl_callable in program.program_callables_info.items():
+        if isinstance(in_knl_callable, CallableKernel):
+            # The helper's signature is (caller_knl, callee_function_name), so
+            # the callables table must not be passed as an extra positional.
+            new_subkernel = (
+                    _match_caller_callee_argument_dimension_for_single_kernel(
+                        in_knl_callable.subkernel, *args, **kwargs))
+            in_knl_callable = in_knl_callable.copy(subkernel=new_subkernel)
+        elif not isinstance(in_knl_callable, ScalarCallable):
+            # Scalar callables pass through unchanged; anything else is unknown.
+            raise NotImplementedError("Unknown type of callable %s." % (
+                type(in_knl_callable).__name__))
+
+        new_resolved_functions[func_id] = in_knl_callable
+
+    new_program_callables_info = program.program_callables_info.copy(
+            resolved_functions=new_resolved_functions)
+    return program.copy(program_callables_info=new_program_callables_info)
+
 # }}}
 
 
diff --git a/loopy/transform/pack_and_unpack_args.py b/loopy/transform/pack_and_unpack_args.py
index 87136d017..734072574 100644
--- a/loopy/transform/pack_and_unpack_args.py
+++ b/loopy/transform/pack_and_unpack_args.py
@@ -24,6 +24,9 @@ THE SOFTWARE.
 
 from loopy.diagnostic import LoopyError
 from loopy.kernel.instruction import CallInstruction
+from loopy.program import Program
+from loopy.kernel import LoopKernel
+from loopy.kernel.function_interface import CallableKernel, ScalarCallable
 from loopy.symbolic import SubArrayRef
 
 __doc__ = """
@@ -33,7 +36,8 @@ __doc__ = """
 """
 
 
-def pack_and_unpack_args_for_call(kernel, call_name, args_to_pack=None,
+def pack_and_unpack_args_for_call_for_single_kernel(kernel,
+        program_callables_info, call_name, args_to_pack=None,
         args_to_unpack=None):
     """
     Returns a a copy of *kernel* with instructions appended to copy the
@@ -50,6 +54,7 @@ def pack_and_unpack_args_for_call(kernel, call_name, args_to_pack=None,
         which must be unpacked. If set *None*, it is interpreted that
         all the array arguments should be unpacked.
     """
+    assert isinstance(kernel, LoopKernel)
     new_domains = []
     new_tmps = kernel.temporary_variables.copy()
     old_insn_to_new_insns = {}
@@ -58,10 +63,10 @@ def pack_and_unpack_args_for_call(kernel, call_name, args_to_pack=None,
         if not isinstance(insn, CallInstruction):
             # pack and unpack call only be done for CallInstructions.
             continue
-        if insn.expression.function.name not in kernel.scoped_functions:
+        if insn.expression.function.name not in program_callables_info:
             continue
 
-        in_knl_callable = kernel.scoped_functions[
+        in_knl_callable = program_callables_info[
                 insn.expression.function.name]
 
         if in_knl_callable.name != call_name:
@@ -314,4 +319,29 @@ def pack_and_unpack_args_for_call(kernel, call_name, args_to_pack=None,
 
     return kernel
 
+
+def pack_and_unpack_args_for_call(program, *args, **kwargs):
+    """Run the single-kernel pack/unpack transformation on each callable
+    kernel of *program*; *args*/*kwargs* are forwarded to it unchanged."""
+    assert isinstance(program, Program)
+
+    callables_info = program.program_callables_info
+    updated_callables = {}
+    for identifier, clbl in callables_info.items():
+        if isinstance(clbl, CallableKernel):
+            clbl = clbl.copy(
+                    subkernel=pack_and_unpack_args_for_call_for_single_kernel(
+                        clbl.subkernel, callables_info, *args, **kwargs))
+        elif not isinstance(clbl, ScalarCallable):
+            # only kernel and scalar callables are known here
+            raise NotImplementedError("Unknown type of callable %s." % (
+                type(clbl).__name__))
+
+        # scalar callables are carried over untouched
+        updated_callables[identifier] = clbl
+
+    return program.copy(
+            program_callables_info=callables_info.copy(
+                resolved_functions=updated_callables))
+
 # vim: foldmethod=marker
diff --git a/test/test_callables.py b/test/test_callables.py
index 9dce5a84a..f25bbbe6f 100644
--- a/test/test_callables.py
+++ b/test/test_callables.py
@@ -52,7 +52,8 @@ def test_register_function_lookup(ctx_factory):
             """
             y[i] = log2(x[i])
             """)
-    prog = lp.register_function_lookup(prog, register_log2_lookup)
+    prog = lp.register_function_id_to_in_knl_callable_mapper(prog,
+            register_log2_lookup)
 
     evt, (out, ) = prog(queue, x=x)
 
@@ -68,17 +69,17 @@ def test_register_knl(ctx_factory, inline):
     x = np.random.rand(n, n, n, n, n)
     y = np.random.rand(n, n, n, n, n)
 
-    grandchild_knl = lp.make_kernel(
+    grandchild_knl = lp.make_kernel_function(
             "{[i, j]:0<= i, j< 16}",
             """
             c[i, j] = 2*a[i, j] + 3*b[i, j]
-            """)
+            """, name='linear_combo1')
 
-    child_knl = lp.make_kernel(
+    child_knl = lp.make_kernel_function(
             "{[i, j]:0<=i, j < 16}",
             """
             [i, j]: g[i, j] = linear_combo1([i, j]: e[i, j], [i, j]: f[i, j])
-            """)
+            """, name='linear_combo2')
 
     parent_knl = lp.make_kernel(
             "{[i, j, k, l, m]: 0<=i, j, k, l, m<16}",
@@ -97,10 +98,10 @@ def test_register_knl(ctx_factory, inline):
                     shape=(16, 16, 16, 16, 16)), '...'],
             )
 
-    child_knl = lp.register_callable_kernel(
-            child_knl, 'linear_combo1', grandchild_knl)
     knl = lp.register_callable_kernel(
-            parent_knl, 'linear_combo2', child_knl)
+            parent_knl, child_knl)
+    knl = lp.register_callable_kernel(
+            knl, grandchild_knl)
     if inline:
         knl = lp.inline_callable_kernel(knl, 'linear_combo2')
         knl = lp.inline_callable_kernel(knl, 'linear_combo1')
@@ -120,11 +121,11 @@ def test_slices_with_negative_step(ctx_factory, inline):
     x = np.random.rand(n, n, n, n, n)
     y = np.random.rand(n, n, n, n, n)
 
-    child_knl = lp.make_kernel(
+    child_knl = lp.make_kernel_function(
             "{[i, j]:0<=i, j < 16}",
             """
             g[i, j] = 2*e[i, j] + 3*f[i, j]
-            """)
+            """, name="linear_combo")
 
     parent_knl = lp.make_kernel(
             "{[i, k, m]: 0<=i, k, m<16}",
@@ -148,7 +149,7 @@ def test_slices_with_negative_step(ctx_factory, inline):
             )
 
     knl = lp.register_callable_kernel(
-            parent_knl, 'linear_combo', child_knl)
+            parent_knl, child_knl)
     if inline:
         knl = lp.inline_callable_kernel(knl, 'linear_combo')
 
@@ -169,7 +170,7 @@ def test_register_knl_with_call_with_kwargs(ctx_factory, inline):
     b_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float32)
     c_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float64)
 
-    callee_knl = lp.make_kernel(
+    callee_knl = lp.make_kernel_function(
             "{[i, j]:0<=i, j < %d}" % n,
             """
             h[i, j] = 2 * e[i, j] + 3*f[i, j] + 4*g[i, j]
@@ -177,11 +178,8 @@ def test_register_knl_with_call_with_kwargs(ctx_factory, inline):
             p[i, j] = 7 * e[i, j] + 4*f1[i, j] + 2*g[i, j]
             """,
             [
-                lp.GlobalArg('f'),
-                lp.GlobalArg('e'),
-                lp.GlobalArg('h'),
-                lp.GlobalArg('g'),
-                '...'])
+                lp.GlobalArg('f, e, h, g'), '...'],
+            name='linear_combo')
 
     caller_knl = lp.make_kernel(
             "{[i, j, k, l, m]: 0<=i, j, k, l, m<%d}" % n,
@@ -194,7 +192,7 @@ def test_register_knl_with_call_with_kwargs(ctx_factory, inline):
             """)
 
     knl = lp.register_callable_kernel(
-            caller_knl, 'linear_combo', callee_knl)
+            caller_knl, callee_knl)
     if inline:
         knl = lp.inline_callable_kernel(knl, 'linear_combo')
 
@@ -223,11 +221,11 @@ def test_register_knl_with_hw_axes(ctx_factory, inline):
     x_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float64)
     y_dev = cl.clrandom.rand(queue, (n, n, n, n, n), np.float64)
 
-    callee_knl = lp.make_kernel(
+    callee_knl = lp.make_kernel_function(
             "{[i, j]:0<=i, j < 16}",
             """
             g[i, j] = 2*e[i, j] + 3*f[i, j]
-            """)
+            """, name='linear_combo')
 
     callee_knl = lp.split_iname(callee_knl, "i", 4, inner_tag="l.0", outer_tag="g.0")
 
@@ -241,7 +239,7 @@ def test_register_knl_with_hw_axes(ctx_factory, inline):
     caller_knl = lp.split_iname(caller_knl, "i", 4, inner_tag="l.1", outer_tag="g.1")
 
     knl = lp.register_callable_kernel(
-            caller_knl, 'linear_combo', callee_knl)
+            caller_knl, callee_knl)
 
     if inline:
         knl = lp.inline_callable_kernel(knl, 'linear_combo')
@@ -264,23 +262,23 @@ def test_shape_translation_through_sub_array_ref(ctx_factory, inline):
     x2 = cl.clrandom.rand(queue, (6, ), dtype=np.float64)
     x3 = cl.clrandom.rand(queue, (6, 6), dtype=np.float64)
 
-    callee1 = lp.make_kernel(
+    callee1 = lp.make_kernel_function(
             "{[i]: 0<=i<6}",
             """
             a[i] = 2*abs(b[i])
-            """)
+            """, name="callee_fn1")
 
-    callee2 = lp.make_kernel(
+    callee2 = lp.make_kernel_function(
             "{[i, j]: 0<=i<3 and 0 <= j < 2}",
             """
             a[i, j] = 3*b[i, j]
-            """)
+            """, name="callee_fn2")
 
-    callee3 = lp.make_kernel(
+    callee3 = lp.make_kernel_function(
             "{[i]: 0<=i<6}",
             """
             a[i] = 5*b[i]
-            """)
+            """, name="callee_fn3")
 
     knl = lp.make_kernel(
             "{[i, j, k, l]:  0<= i < 6 and 0 <= j < 3 and 0 <= k < 2 and 0<=l<6}",
@@ -290,9 +288,9 @@ def test_shape_translation_through_sub_array_ref(ctx_factory, inline):
             [l]: y3[l, l] = callee_fn3([l]: x3[l, l])
             """)
 
-    knl = lp.register_callable_kernel(knl, 'callee_fn1', callee1)
-    knl = lp.register_callable_kernel(knl, 'callee_fn2', callee2)
-    knl = lp.register_callable_kernel(knl, 'callee_fn3', callee3)
+    knl = lp.register_callable_kernel(knl, callee1)
+    knl = lp.register_callable_kernel(knl, callee2)
+    knl = lp.register_callable_kernel(knl, callee3)
 
     if inline:
         knl = lp.inline_callable_kernel(knl, 'callee_fn1')
@@ -321,7 +319,7 @@ def test_multi_arg_array_call(ctx_factory):
     i = p.Variable("i")
     index = p.Variable("index")
     a_i = p.Subscript(p.Variable("a"), p.Variable("i"))
-    argmin_kernel = lp.make_kernel(
+    argmin_kernel = lp.make_kernel_function(
             "{[i]: 0 <= i < n}",
             [
                 lp.Assignment(id="init2", assignee=index,
@@ -333,7 +331,8 @@ def test_multi_arg_array_call(ctx_factory):
                     depends_on="update"),
                 lp.Assignment(id="update", assignee=acc_i,
                     expression=p.Variable("min")(acc_i, a_i),
-                    depends_on="init1,init2")])
+                    depends_on="init1,init2")],
+            name="custom_argmin")
 
     argmin_kernel = lp.fix_parameters(argmin_kernel, n=n)
 
@@ -346,7 +345,7 @@ def test_multi_arg_array_call(ctx_factory):
     knl = lp.fix_parameters(knl, n=n)
     knl = lp.set_options(knl, return_dict=True)
 
-    knl = lp.register_callable_kernel(knl, "custom_argmin", argmin_kernel)
+    knl = lp.register_callable_kernel(knl, argmin_kernel)
     b = np.random.randn(n)
     evt, out_dict = knl(queue, b=b)
     tol = 1e-15
@@ -363,17 +362,17 @@ def test_packing_unpacking(ctx_factory, inline):
     x1 = cl.clrandom.rand(queue, (3, 2), dtype=np.float64)
     x2 = cl.clrandom.rand(queue, (6, ), dtype=np.float64)
 
-    callee1 = lp.make_kernel(
+    callee1 = lp.make_kernel_function(
             "{[i]: 0<=i<6}",
             """
             a[i] = 2*b[i]
-            """)
+            """, name="callee_fn1")
 
-    callee2 = lp.make_kernel(
+    callee2 = lp.make_kernel_function(
             "{[i, j]: 0<=i<2 and 0 <= j < 3}",
             """
             a[i, j] = 3*b[i, j]
-            """)
+            """, name="callee_fn2")
 
     knl = lp.make_kernel(
             "{[i, j, k]:  0<= i < 3 and 0 <= j < 2 and 0 <= k < 6}",
@@ -382,8 +381,8 @@ def test_packing_unpacking(ctx_factory, inline):
             [k]: y2[k] = callee_fn2([k]: x2[k])
             """)
 
-    knl = lp.register_callable_kernel(knl, 'callee_fn1', callee1)
-    knl = lp.register_callable_kernel(knl, 'callee_fn2', callee2)
+    knl = lp.register_callable_kernel(knl, callee1)
+    knl = lp.register_callable_kernel(knl, callee2)
 
     knl = lp.pack_and_unpack_args_for_call(knl, 'callee_fn1')
     knl = lp.pack_and_unpack_args_for_call(knl, 'callee_fn2')
diff --git a/test/testlib.py b/test/testlib.py
index 106a07aeb..eebc792d0 100644
--- a/test/testlib.py
+++ b/test/testlib.py
@@ -139,12 +139,14 @@ class SeparateTemporariesPreambleTestPreambleGenerator(
 
 class Log2Callable(lp.ScalarCallable):
 
-    def with_types(self, arg_id_to_dtype, kernel):
+    def with_types(self, arg_id_to_dtype, kernel, program_callables_info):
 
         if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None:
             # the types provided aren't mature enough to specialize the
             # callable
-            return self.copy(arg_id_to_dtype=arg_id_to_dtype)
+            return (
+                    self.copy(arg_id_to_dtype=arg_id_to_dtype),
+                    program_callables_info)
 
         dtype = arg_id_to_dtype[0].numpy_dtype
 
@@ -162,8 +164,11 @@ class Log2Callable(lp.ScalarCallable):
                 name_in_target = "log2l"
 
         from loopy.types import NumpyType
-        return self.copy(name_in_target=name_in_target,
-                arg_id_to_dtype={0: NumpyType(dtype), -1: NumpyType(dtype)})
+        return (
+                self.copy(name_in_target=name_in_target,
+                    arg_id_to_dtype={0: NumpyType(dtype), -1:
+                        NumpyType(dtype)}),
+                program_callables_info)
 
 
 def register_log2_lookup(target, identifier):
-- 
GitLab