diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py index a62ba7147254f1ac04d429a2862acd22a82c8aeb..6e28d9e7b969372a714af78a3b772f0052347e39 100644 --- a/loopy/transform/fusion.py +++ b/loopy/transform/fusion.py @@ -130,9 +130,6 @@ def _merge_values(item_name, val_a, val_b): # {{{ two-kernel fusion def _fuse_two_kernels(kernela, kernelb): - from loopy.kernel import KernelState - if kernela.state != KernelState.INITIAL or kernelb.state != KernelState.INITIAL: - raise LoopyError("can only fuse kernels in INITIAL state") # {{{ fuse domains @@ -333,20 +330,42 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None): # namespace, otherwise the kernel names should be uniquified. # We should also somehow be able to know that callables like "sin"/"cos" # belong to the global namespace and need not be uniquified. + if all(isinstance(kernel, TranslationUnit) for kernel in kernels): - new_kernels = [] + # {{{ sanity checks + for knl in kernels: - kernel_names = [i for i, clbl in - knl.callables_table.items() if isinstance(clbl, - CallableKernel)] - if len(kernel_names) != 1: - raise NotImplementedError("Kernel containing more than one" - " callable kernel, not allowed for now.") - new_kernels.append(knl[kernel_names[0]]) + nkernels = len([i for i, clbl in knl.callables_table.items() + if isinstance(clbl, CallableKernel)]) + if nkernels != 1: + raise NotImplementedError("Translation unit with more than one" + " callable kernel not allowed for now.") + + # }}} + + # {{{ "merge" the callable namespace + + from loopy.transform.callable import rename_callable + loop_kernels_to_be_fused = [] + new_callables = {} - kernels = new_kernels[:] + for t_unit in kernels: + for name in set(t_unit.callables_table) & set(new_callables): + t_unit = rename_callable(t_unit, name) + + for name, clbl in t_unit.callables_table.items(): + if isinstance(clbl, CallableKernel): + loop_kernels_to_be_fused.append(clbl.subkernel) + else: + new_callables[name] = clbl + + # }}} + + kernels = loop_kernels_to_be_fused[:] + else: + assert all(isinstance(knl, LoopKernel) for knl in kernels) + new_callables = {} - assert all(isinstance(knl, LoopKernel) for knl in kernels) kernels = list(kernels) if data_flow is None: @@ -425,7 +444,11 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None): # }}} - from loopy.translation_unit import make_program - return make_program(result).with_entrypoints(result.name) + new_callables[result.name] = CallableKernel(result) + + return TranslationUnit(callables_table=new_callables, + target=result.target, + entrypoints=frozenset([result.name])) + # vim: foldmethod=marker