diff --git a/loopy/transform/buffer.py b/loopy/transform/buffer.py index b848a6f98abdc91ea6d9fd8052cae7de035cdfb3..57c4397f998a74222bc482127076dee809de2bac 100644 --- a/loopy/transform/buffer.py +++ b/loopy/transform/buffer.py @@ -245,7 +245,7 @@ def buffer_array_for_single_kernel(kernel, program_callables_info, var_name, from loopy.preprocess import prepare_for_caching key_kernel = prepare_for_caching(kernel) - cache_key = (key_kernel, program_callables_info, var_name, + cache_key = (key_kernel, var_name, tuple(buffer_inames), PymbolicExpressionHashWrapper(init_expression), PymbolicExpressionHashWrapper(store_expression), within, diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 0d5f2015e293e977ebf29ee06b29bed9f6c20a73..20dc9a99bddf18ffcc7b381275614f24e95f6ed8 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -1088,6 +1088,7 @@ def has_schedulable_iname_nesting(knl): # {{{ rename_inames +@iterate_over_kernels_if_given_program def rename_iname(knl, old_iname, new_iname, existing_ok=False, within=None): """ :arg within: a stack match as understood by diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index 0dbc7939e681d45c0b292ed49b1eeb244c97f82e..6a93e0bd99bc779f66fa3fb0aea67b55ea246740 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -35,6 +35,7 @@ from pymbolic import var from loopy.program import iterate_over_kernels_if_given_program from loopy.kernel import LoopKernel +from loopy.kernel.function_interface import CallableKernel, ScalarCallable import logging logger = logging.getLogger(__name__) @@ -508,8 +509,17 @@ def find_rules_matching(knl, pattern): return [r for r in knl.substitutions if pattern.match(r)] -def find_one_rule_matching(knl, pattern): - rules = find_rules_matching(knl, pattern) +def find_one_rule_matching(program, pattern): + rules = [] + for in_knl_callable in program.program_callables_info.values(): + if isinstance(in_knl_callable, CallableKernel): + knl = in_knl_callable.subkernel + rules.extend(find_rules_matching(knl, pattern)) + elif isinstance(in_knl_callable, ScalarCallable): + pass + else: + raise NotImplementedError("Unknown callable types %s." % ( + type(in_knl_callable).__name__)) if len(rules) > 1: raise ValueError("more than one substitution rule matched '%s'" diff --git a/test/test_fortran.py b/test/test_fortran.py index deca4d42e6ffcfc10292039bb17449ffb141f112..1a5a0c383615e2ca5a59b34fc1aef6300f14d89f 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -410,7 +410,7 @@ def test_fuse_kernels(ctx_factory): knl = lp.fuse_kernels((xderiv, yderiv), data_flow=[("result", 0, 1)]) knl = lp.prioritize_loops(knl, "e,i,j,k") - assert len(knl.temporary_variables) == 2 + assert len(knl.root_kernel.temporary_variables) == 2 ctx = ctx_factory() lp.auto_test_vs_ref(xyderiv, ctx, knl, parameters=dict(nelements=20, ndofs=4)) diff --git a/test/test_numa_diff.py b/test/test_numa_diff.py index 6b578838d99cb5aa28296619fdec6e8a2359ba0b..4f802f8bff3ba00763825bc09dbc6051ff1ac527 100644 --- a/test/test_numa_diff.py +++ b/test/test_numa_diff.py @@ -246,7 +246,9 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level): # noqa "-cl-no-signed-zeros", ]) - hsv = hsv.copy(name="horizontalStrongVolumeKernel") + # FIXME: renaming's a bit tricky in this program model. + # add a simple transformation for it + # hsv = hsv.copy(name="horizontalStrongVolumeKernel") results = lp.auto_test_vs_ref(ref_hsv, ctx, hsv, parameters=dict(elements=300), quiet=True)