From dbde6bb6eb428d021b8123571a477c65972f876c Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 22 May 2021 18:38:26 -0500
Subject: [PATCH] kernel fusion: remove the restriction that only unresolved
 translation units can be fused

---
 loopy/transform/fusion.py | 53 ++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 15 deletions(-)

diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py
index a62ba7147..6e28d9e7b 100644
--- a/loopy/transform/fusion.py
+++ b/loopy/transform/fusion.py
@@ -130,9 +130,6 @@ def _merge_values(item_name, val_a, val_b):
 # {{{ two-kernel fusion
 
 def _fuse_two_kernels(kernela, kernelb):
-    from loopy.kernel import KernelState
-    if kernela.state != KernelState.INITIAL or kernelb.state != KernelState.INITIAL:
-        raise LoopyError("can only fuse kernels in INITIAL state")
 
     # {{{ fuse domains
 
@@ -333,20 +330,42 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None):
     # namespace, otherwise the kernel names should be uniquified.
     # We should also somehow be able to know that callables like "sin"/"cos"
     # belong to the global namespace and need not be uniquified.
+
     if all(isinstance(kernel, TranslationUnit) for kernel in kernels):
-        new_kernels = []
+        # {{{ sanity checks
+
         for knl in kernels:
-            kernel_names = [i for i, clbl in
-                    knl.callables_table.items() if isinstance(clbl,
-                        CallableKernel)]
-            if len(kernel_names) != 1:
-                raise NotImplementedError("Kernel containing more than one"
-                        " callable kernel, not allowed for now.")
-            new_kernels.append(knl[kernel_names[0]])
+            nkernels = len([i for i, clbl in knl.callables_table.items()
+                            if isinstance(clbl, CallableKernel)])
+            if nkernels != 1:
+                raise NotImplementedError("Translation unit with more than one"
+                        " callable kernel not allowed for now.")
+
+        # }}}
+
+        # {{{ "merge" the callable namespace
+
+        from loopy.transform.callable import rename_callable
+        loop_kernels_to_be_fused = []
+        new_callables = {}
 
-        kernels = new_kernels[:]
+        for t_unit in kernels:
+            for name in set(t_unit.callables_table) & set(new_callables):
+                t_unit = rename_callable(t_unit, name)
+
+            for name, clbl in t_unit.callables_table.items():
+                if isinstance(clbl, CallableKernel):
+                    loop_kernels_to_be_fused.append(clbl.subkernel)
+                else:
+                    new_callables[name] = clbl
+
+        # }}}
+
+        kernels = loop_kernels_to_be_fused[:]
+    else:
+        assert all(isinstance(knl, LoopKernel) for knl in kernels)
+        new_callables = {}
 
-    assert all(isinstance(knl, LoopKernel) for knl in kernels)
     kernels = list(kernels)
 
     if data_flow is None:
@@ -425,7 +444,11 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None):
 
     # }}}
 
-    from loopy.translation_unit import make_program
-    return make_program(result).with_entrypoints(result.name)
+    new_callables[result.name] = CallableKernel(result)
+
+    return TranslationUnit(callables_table=new_callables,
+                           target=result.target,
+                           entrypoints=frozenset([result.name]))
+
 
 # vim: foldmethod=marker
-- 
GitLab