From dbde6bb6eb428d021b8123571a477c65972f876c Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sat, 22 May 2021 18:38:26 -0500 Subject: [PATCH] kernel fusion: removes the restriction that only unresolved translation units could be fused --- loopy/transform/fusion.py | 53 ++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py index a62ba7147..6e28d9e7b 100644 --- a/loopy/transform/fusion.py +++ b/loopy/transform/fusion.py @@ -130,9 +130,6 @@ def _merge_values(item_name, val_a, val_b): # {{{ two-kernel fusion def _fuse_two_kernels(kernela, kernelb): - from loopy.kernel import KernelState - if kernela.state != KernelState.INITIAL or kernelb.state != KernelState.INITIAL: - raise LoopyError("can only fuse kernels in INITIAL state") # {{{ fuse domains @@ -333,20 +330,42 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None): # namespace, otherwise the kernel names should be uniquified. # We should also somehow be able to know that callables like "sin"/"cos" # belong to the global namespace and need not be uniquified. + if all(isinstance(kernel, TranslationUnit) for kernel in kernels): - new_kernels = [] + # {{{ sanity checks + for knl in kernels: - kernel_names = [i for i, clbl in - knl.callables_table.items() if isinstance(clbl, - CallableKernel)] - if len(kernel_names) != 1: - raise NotImplementedError("Kernel containing more than one" - " callable kernel, not allowed for now.") - new_kernels.append(knl[kernel_names[0]]) + nkernels = len([i for i, clbl in knl.callables_table.items() + if isinstance(clbl, CallableKernel)]) + if nkernels != 1: + raise NotImplementedError("Translation unit with more than one" + " callable kernel not allowed for now.") + + # }}} + + # {{{ "merge" the callable namespace + + from loopy.transform.callable import rename_callable + loop_kernels_to_be_fused = [] + new_callables = {} - kernels = new_kernels[:] + for t_unit in kernels: + for name in set(t_unit.callables_table) & set(new_callables): + t_unit = rename_callable(t_unit, name) + + for name, clbl in t_unit.callables_table.items(): + if isinstance(clbl, CallableKernel): + loop_kernels_to_be_fused.append(clbl.subkernel) + else: + new_callables[name] = clbl + + # }}} + + kernels = loop_kernels_to_be_fused[:] + else: + assert all(isinstance(knl, LoopKernel) for knl in kernels) + new_callables = {} - assert all(isinstance(knl, LoopKernel) for knl in kernels) kernels = list(kernels) if data_flow is None: @@ -425,7 +444,11 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None): # }}} - from loopy.translation_unit import make_program - return make_program(result).with_entrypoints(result.name) + new_callables[result.name] = CallableKernel(result) + + return TranslationUnit(callables_table=new_callables, + target=result.target, + entrypoints=frozenset([result.name])) + # vim: foldmethod=marker -- GitLab