From 87c9fa7e1199d821b61b629b40193623fb3530bf Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 21 Oct 2019 08:03:24 -0500
Subject: [PATCH] miscellaneous minor fixes:

---
 loopy/__init__.py           |  2 +-
 loopy/codegen/__init__.py   |  7 +++++--
 loopy/kernel/creation.py    | 26 +++++++++++++++++---------
 loopy/transform/callable.py |  8 ++++----
 test/test_callables.py      | 11 ++++-------
 test/testlib.py             |  5 -----
 6 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 15a670583..8f21cac56 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -178,7 +178,7 @@ __all__ = [
 
         "ScalarCallable", "CallableKernel",
 
-        "CallablesTable", "Program", "make_program",
+        "Program", "make_program",
 
         "KernelArgument",
         "ValueArg", "ArrayArg", "GlobalArg", "ConstantArg", "ImageArg",
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 0f8028e43..8d5bd14f4 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -199,7 +199,8 @@ class CodeGenerationState(object):
 
     .. attribute:: callables_table
 
-        An instance of :class:`loopy.CallablesTable`.
+        A mapping from callable names to instances of
+        :class:`loopy.kernel.function_interface.InKernelCallable`.
 
     .. attribute:: is_entrypoint
 
@@ -699,11 +700,13 @@ def generate_code_v2(program):
     device_programs = ([device_programs[0].copy(
             ast=Collection(callee_fdecls+[device_programs[0].ast]))] +
             device_programs[1:])
-    return CodeGenerationResult(
+    cgr = CodeGenerationResult(
             host_programs=host_programs,
             device_programs=device_programs,
             implemented_data_infos=implemented_data_infos)
 
+    return cgr
+
 
 def generate_code(kernel, device=None):
     if device is not None:
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index c6081156f..242389384 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -1900,14 +1900,18 @@ class SliceToInameReplacer(IdentityMapper):
             if isinstance(index, Slice):
                 unique_var_name = self.var_name_gen(based_on="i")
                 if expr.aggregate.name in self.knl.arg_dict:
-                    domain_length = self.knl.arg_dict[expr.aggregate.name].shape[i]
-                elif expr.aggregate.name in self.knl.temporary_variables:
-                    domain_length = self.knl.temporary_variables[
-                            expr.aggregate.name].shape[i]
+                    shape = self.knl.arg_dict[expr.aggregate.name].shape
                 else:
+                    assert expr.aggregate.name in self.knl.temporary_variables
+                    shape = self.knl.temporary_variables[
+                            expr.aggregate.name].shape
+                if shape is None or shape[i] is None:
                     raise LoopyError("Slice notation is only supported for "
                             "variables whose shapes are known at creation time "
-                            "-- maybe add the shape for the sliced argument.")
+                            "-- maybe add the shape for '{}'.".format(
+                                expr.aggregate.name))
+
+                domain_length = shape[i]
                 start, stop, step = get_slice_params(
                         index, domain_length)
                 subscript_iname_bounds[unique_var_name] = (start, stop, step)
@@ -2025,7 +2029,7 @@ def realize_slices_array_inputs_as_sub_array_refs(kernel):
 
 # {{{ kernel creation top-level
 
-def make_kernel(domains, instructions, kernel_data=["..."], **kwargs):
+def make_function(domains, instructions, kernel_data=["..."], **kwargs):
     """User-facing kernel creation entrypoint.
 
     :arg domains:
@@ -2378,9 +2382,13 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs):
     return make_program(knl)
 
 
-def make_function(*args, **kwargs):
-    #FIXME: Do we need this anymore??
-    return make_kernel(*args, **kwargs)
+def make_kernel(*args, **kwargs):
+    tunit = make_function(*args, **kwargs)
+    name, = [name for name in tunit.callables_table]
+    return tunit.with_entrypoints(name)
+
+
+make_kernel.__doc__ = make_function.__doc__
 
 # }}}
 
diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py
index c96a51778..f2e1bead5 100644
--- a/loopy/transform/callable.py
+++ b/loopy/transform/callable.py
@@ -42,7 +42,7 @@ from loopy.symbolic import SubArrayRef
 __doc__ = """
 .. currentmodule:: loopy
 
-.. autofunction:: register_function_id_to_in_knl_callable_mapper
+.. autofunction:: register_callable
 
 .. autofunction:: fuse_translation_units
 """
@@ -61,16 +61,16 @@ def register_callable(translation_unit, function_identifier, callable_,
     from loopy.kernel.function_interface import InKernelCallable
     assert isinstance(callable_, InKernelCallable)
 
-    if (function_identifier in translation_unit.callables) and (
+    if (function_identifier in translation_unit.callables_table) and (
             redefining_not_ok):
         raise LoopyError("Redifining function identifier not allowed. Set the"
                 " option 'redefining_not_ok=False' to bypass this error.")
 
-    callables = translation_unit.copy()
+    callables = translation_unit.callables_table.copy()
     callables[function_identifier] = callable_
 
     return translation_unit.copy(
-            callables=callables)
+            callables_table=callables)
 
 
 def fuse_translation_units(translation_units, collision_not_ok=True):
diff --git a/test/test_callables.py b/test/test_callables.py
index 04eeae666..17f9a3c0a 100644
--- a/test/test_callables.py
+++ b/test/test_callables.py
@@ -41,7 +41,7 @@ def test_register_function_lookup(ctx_factory):
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
 
-    from testlib import register_log2_lookup
+    from testlib import Log2Callable
 
     x = np.random.rand(10)
     queue = cl.CommandQueue(ctx)
@@ -51,8 +51,7 @@ def test_register_function_lookup(ctx_factory):
             """
             y[i] = log2(x[i])
             """)
-    prog = lp.register_function_id_to_in_knl_callable_mapper(prog,
-            register_log2_lookup)
+    prog = lp.register_callable(prog, 'log2', Log2Callable('log2'))
 
     evt, (out, ) = prog(queue, x=x)
 
@@ -94,10 +93,8 @@ def test_register_knl(ctx_factory, inline):
                 '...']
             )
 
-    knl = lp.register_callable_kernel(
-            parent_knl, child_knl)
-    knl = lp.register_callable_kernel(
-            knl, grandchild_knl)
+    knl = lp.fuse_translation_units([grandchild_knl, child_knl, parent_knl])
+
     if inline:
         knl = lp.inline_callable_kernel(knl, 'linear_combo2')
         knl = lp.inline_callable_kernel(knl, 'linear_combo1')
diff --git a/test/testlib.py b/test/testlib.py
index 853e2584a..4f45e69b5 100644
--- a/test/testlib.py
+++ b/test/testlib.py
@@ -171,11 +171,6 @@ class Log2Callable(lp.ScalarCallable):
                 callables_table)
 
 
-def register_log2_lookup(target, identifier):
-    if identifier == 'log2':
-        return Log2Callable(name='log2')
-    return None
-
 # }}}
 
 # vim: foldmethod=marker
-- 
GitLab