From 3026768f99390bc6412efac86c0eb9c2b0f9f7b9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 12 Jun 2015 11:41:14 -0500
Subject: [PATCH] Kernel fusion: rename temporaries to make them unique

---
 loopy/fusion.py      | 99 ++++++++++++++++++++++++++++++--------------
 test/test_fortran.py | 21 +++++++---
 2 files changed, 83 insertions(+), 37 deletions(-)

diff --git a/loopy/fusion.py b/loopy/fusion.py
index 306138860..4431c2c7f 100644
--- a/loopy/fusion.py
+++ b/loopy/fusion.py
@@ -29,6 +29,7 @@ import islpy as isl
 from islpy import dim_type
 
 from loopy.diagnostic import LoopyError
+from pymbolic import var
 
 
 def _find_fusable_loop_domain_index(domain, other_domains):
@@ -95,32 +96,6 @@ def _fuse_two_kernels(knla, knlb):
     if knla.state != kernel_state.INITIAL or knlb.state != kernel_state.INITIAL:
         raise LoopyError("can only fuse kernels in INITIAL state")
 
-    # {{{ fuse instructions
-
-    new_instructions = knla.instructions[:]
-    from pytools import UniqueNameGenerator
-    insn_id_gen = UniqueNameGenerator(
-            set([insna.id for insna in new_instructions]))
-
-    knl_b_instructions = []
-    old_b_id_to_new_b_id = {}
-    for insnb in knlb.instructions:
-        old_id = insnb.id
-        new_id = insn_id_gen(old_id)
-        old_b_id_to_new_b_id[old_id] = new_id
-
-        knl_b_instructions.append(
-                insnb.copy(id=new_id))
-
-    for insnb in knl_b_instructions:
-        new_instructions.append(
-                insnb.copy(
-                    insn_deps=frozenset(
-                        old_b_id_to_new_b_id[dep_id]
-                        for dep_id in insnb.insn_deps)))
-
-    # }}}
-
     # {{{ fuse domains
 
     new_domains = knla.domains[:]
@@ -151,12 +126,20 @@ def _fuse_two_kernels(knla, knlb):
 
     # }}}
 
+    vng = knla.get_var_name_generator()
+    b_var_renames = {}
+
     # {{{ fuse args
 
     new_args = knla.args[:]
     for b_arg in knlb.args:
         if b_arg.name not in knla.arg_dict:
-            new_args.append(b_arg)
+            new_arg_name = vng(b_arg.name)
+
+            if new_arg_name != b_arg.name:
+                b_var_renames[b_arg.name] = var(new_arg_name)
+
+            new_args.append(b_arg.copy(name=new_arg_name))
         else:
             if b_arg != knla.arg_dict[b_arg.name]:
                 raise LoopyError(
@@ -165,6 +148,63 @@ def _fuse_two_kernels(knla, knlb):
 
     # }}}
 
+    # {{{ fuse temporaries
+
+    new_temporaries = knla.temporary_variables.copy()
+    for b_name, b_tv in six.iteritems(knlb.temporary_variables):
+        assert b_name == b_tv.name
+
+        new_tv_name = vng(b_name)
+
+        if new_tv_name != b_name:
+            b_var_renames[b_name] = var(new_tv_name)
+
+        assert new_tv_name not in new_temporaries
+        new_temporaries[new_tv_name] = b_tv.copy(name=new_tv_name)
+
+    # }}}
+
+    # {{{ apply renames in kernel b
+
+    from loopy.symbolic import (
+            SubstitutionRuleMappingContext,
+            RuleAwareSubstitutionMapper)
+    from pymbolic.mapper.substitutor import make_subst_func
+
+    srmc = SubstitutionRuleMappingContext(
+            knlb.substitutions, knlb.get_var_name_generator())
+    subst_map = RuleAwareSubstitutionMapper(
+            srmc, make_subst_func(b_var_renames), within=lambda stack: True)
+    knlb = subst_map.map_kernel(knlb)
+
+    # }}}
+
+    # {{{ fuse instructions
+
+    new_instructions = knla.instructions[:]
+    from pytools import UniqueNameGenerator
+    insn_id_gen = UniqueNameGenerator(
+            set([insna.id for insna in new_instructions]))
+
+    knl_b_instructions = []
+    old_b_id_to_new_b_id = {}
+    for insnb in knlb.instructions:
+        old_id = insnb.id
+        new_id = insn_id_gen(old_id)
+        old_b_id_to_new_b_id[old_id] = new_id
+
+        knl_b_instructions.append(
+                insnb.copy(id=new_id))
+
+    for insnb in knl_b_instructions:
+        new_instructions.append(
+                insnb.copy(
+                    insn_deps=frozenset(
+                        old_b_id_to_new_b_id[dep_id]
+                        for dep_id in insnb.insn_deps)))
+
+    # }}}
+
     # {{{ fuse assumptions
 
     assump_a = knla.assumptions
@@ -198,10 +238,7 @@ def _fuse_two_kernels(knla, knlb):
             assumptions=new_assumptions,
             local_sizes=_merge_dicts(
                 "local size", knla.local_sizes, knlb.local_sizes),
-            temporary_variables=_merge_dicts(
-                "temporary variable",
-                knla.temporary_variables,
-                knlb.temporary_variables),
+            temporary_variables=new_temporaries,
             iname_to_tag=_merge_dicts(
                 "iname-to-tag mapping",
                 knla.iname_to_tag,
diff --git a/test/test_fortran.py b/test/test_fortran.py
index 124035e52..68bb94cbd 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -372,12 +372,13 @@ def test_fuse_kernels(ctx_factory):
           real*8 result(nelements, ndofs, ndofs)
           real*8 q(nelements, ndofs, ndofs)
           real*8 d(ndofs, ndofs)
+          real*8 prev
 
           do e = 1,nelements
             do i = 1,ndofs
               do j = 1,ndofs
                 do k = 1,ndofs
-                  {line}
+                  {inner}
                 end do
               end do
             end do
@@ -385,20 +386,28 @@ def test_fuse_kernels(ctx_factory):
         end subroutine
         """
 
-    xd_line = "result(e,i,j) = result(e,i,j) + d(i,k)*q(e,i,k)"
-    yd_line = "result(e,i,j) = result(e,i,j) + d(i,k)*q(e,k,j)"
+    xd_line = """
+        prev = result(e,i,j)
+        result(e,i,j) = prev + d(i,k)*q(e,i,k)
+        """
+    yd_line = """
+        prev = result(e,i,j)
+        result(e,i,j) = prev + d(i,k)*q(e,k,j)
+        """
 
     xderiv, = lp.parse_fortran(
-            fortran_template.format(line=xd_line, name="xderiv"))
+            fortran_template.format(inner=xd_line, name="xderiv"))
     yderiv, = lp.parse_fortran(
-            fortran_template.format(line=yd_line, name="yderiv"))
+            fortran_template.format(inner=yd_line, name="yderiv"))
     xyderiv, = lp.parse_fortran(
             fortran_template.format(
-                line=(xd_line + "\n" + yd_line), name="xyderiv"))
+                inner=(xd_line + "\n" + yd_line), name="xyderiv"))
 
     knl = lp.fuse_kernels((xderiv, yderiv))
     knl = lp.set_loop_priority(knl, "e,i,j,k")
 
+    assert len(knl.temporary_variables) == 2
+
     ctx = ctx_factory()
     lp.auto_test_vs_ref(xyderiv, ctx, knl, parameters=dict(nelements=20, ndofs=4))
 
-- 
GitLab