From 87c9fa7e1199d821b61b629b40193623fb3530bf Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 21 Oct 2019 08:03:24 -0500 Subject: [PATCH] miscellaneous minor fixes: --- loopy/__init__.py | 2 +- loopy/codegen/__init__.py | 7 +++++-- loopy/kernel/creation.py | 26 +++++++++++++++++--------- loopy/transform/callable.py | 8 ++++---- test/test_callables.py | 11 ++++------- test/testlib.py | 5 ----- 6 files changed, 31 insertions(+), 28 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index 15a670583..8f21cac56 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -178,7 +178,7 @@ __all__ = [ "ScalarCallable", "CallableKernel", - "CallablesTable", "Program", "make_program", + "Program", "make_program", "KernelArgument", "ValueArg", "ArrayArg", "GlobalArg", "ConstantArg", "ImageArg", diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 0f8028e43..8d5bd14f4 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -199,7 +199,8 @@ class CodeGenerationState(object): .. attribute:: callables_table - An instance of :class:`loopy.CallablesTable`. + A mapping from callable names to instances of + :class:`loopy.kernel.function_interface.InKernelCallable`. .. attribute:: is_entrypoint @@ -699,11 +700,13 @@ def generate_code_v2(program): device_programs = ([device_programs[0].copy( ast=Collection(callee_fdecls+[device_programs[0].ast]))] + device_programs[1:]) - return CodeGenerationResult( + cgr = CodeGenerationResult( host_programs=host_programs, device_programs=device_programs, implemented_data_infos=implemented_data_infos) + return cgr + def generate_code(kernel, device=None): if device is not None: diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index c6081156f..242389384 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1900,14 +1900,18 @@ class SliceToInameReplacer(IdentityMapper): if isinstance(index, Slice): unique_var_name = self.var_name_gen(based_on="i") if expr.aggregate.name in self.knl.arg_dict: - domain_length = self.knl.arg_dict[expr.aggregate.name].shape[i] - elif expr.aggregate.name in self.knl.temporary_variables: - domain_length = self.knl.temporary_variables[ - expr.aggregate.name].shape[i] + shape = self.knl.arg_dict[expr.aggregate.name].shape else: + assert expr.aggregate.name in self.knl.temporary_variables + shape = self.knl.temporary_variables[ + expr.aggregate.name].shape + if shape is None or shape[i] is None: raise LoopyError("Slice notation is only supported for " "variables whose shapes are known at creation time " - "-- maybe add the shape for the sliced argument.") + "-- maybe add the shape for '{}'.".format( + expr.aggregate.name)) + + domain_length = shape[i] start, stop, step = get_slice_params( index, domain_length) subscript_iname_bounds[unique_var_name] = (start, stop, step) @@ -2025,7 +2029,7 @@ def realize_slices_array_inputs_as_sub_array_refs(kernel): # {{{ kernel creation top-level -def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): +def make_function(domains, instructions, kernel_data=["..."], **kwargs): """User-facing kernel creation entrypoint. :arg domains: @@ -2378,9 +2382,13 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): return make_program(knl) -def make_function(*args, **kwargs): - #FIXME: Do we need this anymore?? - return make_kernel(*args, **kwargs) +def make_kernel(*args, **kwargs): + tunit = make_function(*args, **kwargs) + name, = [name for name in tunit.callables_table] + return tunit.with_entrypoints(name) + + +make_kernel.__doc__ = make_function.__doc__ # }}} diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index c96a51778..f2e1bead5 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -42,7 +42,7 @@ from loopy.symbolic import SubArrayRef __doc__ = """ .. currentmodule:: loopy -.. autofunction:: register_function_id_to_in_knl_callable_mapper +.. autofunction:: register_callable .. autofunction:: fuse_translation_units """ @@ -61,16 +61,16 @@ def register_callable(translation_unit, function_identifier, callable_, from loopy.kernel.function_interface import InKernelCallable assert isinstance(callable_, InKernelCallable) - if (function_identifier in translation_unit.callables) and ( + if (function_identifier in translation_unit.callables_table) and ( redefining_not_ok): raise LoopyError("Redifining function identifier not allowed. Set the" " option 'redefining_not_ok=False' to bypass this error.") - callables = translation_unit.copy() + callables = translation_unit.callables_table.copy() callables[function_identifier] = callable_ return translation_unit.copy( - callables=callables) + callables_table=callables) def fuse_translation_units(translation_units, collision_not_ok=True): diff --git a/test/test_callables.py b/test/test_callables.py index 04eeae666..17f9a3c0a 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -41,7 +41,7 @@ def test_register_function_lookup(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) - from testlib import register_log2_lookup + from testlib import Log2Callable x = np.random.rand(10) queue = cl.CommandQueue(ctx) @@ -51,8 +51,7 @@ def test_register_function_lookup(ctx_factory): """ y[i] = log2(x[i]) """) - prog = lp.register_function_id_to_in_knl_callable_mapper(prog, - register_log2_lookup) + prog = lp.register_callable(prog, 'log2', Log2Callable('log2')) evt, (out, ) = prog(queue, x=x) @@ -94,10 +93,8 @@ def test_register_knl(ctx_factory, inline): '...'] ) - knl = lp.register_callable_kernel( - parent_knl, child_knl) - knl = lp.register_callable_kernel( - knl, grandchild_knl) + knl = lp.fuse_translation_units([grandchild_knl, child_knl, parent_knl]) + if inline: knl = lp.inline_callable_kernel(knl, 'linear_combo2') knl = lp.inline_callable_kernel(knl, 'linear_combo1') diff --git a/test/testlib.py b/test/testlib.py index 853e2584a..4f45e69b5 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -171,11 +171,6 @@ class Log2Callable(lp.ScalarCallable): callables_table) -def register_log2_lookup(target, identifier): - if identifier == 'log2': - return Log2Callable(name='log2') - return None - # }}} # vim: foldmethod=marker -- GitLab