From e9baab6dbb9ba6f6532e9bce1dd9f2d9101fceb1 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 11 Nov 2011 12:32:24 -0500
Subject: [PATCH] Preserve tags when duplicating reduction inames, plus more
 FEM assembly fixes.

---
 loopy/preprocess.py       |  6 +++++-
 test/test_fem_assembly.py | 26 ++++++++++----------------
 2 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 73c474e31..fec2f3e80 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -84,6 +84,8 @@ def duplicate_reduction_inames(kernel):
     new_domain = kernel.domain
     new_insns = []
 
+    new_iname_to_tag = kernel.iname_to_tag.copy()
+
     for insn in kernel.instructions:
         old_insn_inames = []
         new_insn_inames = []
@@ -96,10 +98,12 @@ def duplicate_reduction_inames(kernel):
         from loopy.isl_helpers import duplicate_axes
         for old, new in zip(old_insn_inames, new_insn_inames):
             new_domain = duplicate_axes(new_domain, [old], [new])
+            new_iname_to_tag[new] = kernel.iname_to_tag[old]
 
     return kernel.copy(
             instructions=new_insns,
-            domain=new_domain)
+            domain=new_domain,
+            iname_to_tag=new_iname_to_tag)
 
 # }}}
 
diff --git a/test/test_fem_assembly.py b/test/test_fem_assembly.py
index 77a5757a0..bae8fcdca 100644
--- a/test/test_fem_assembly.py
+++ b/test/test_fem_assembly.py
@@ -46,7 +46,7 @@ def test_laplacian_stiffness(ctx_factory):
             ],
             name="lapquad", assumptions="Nc>=1")
 
-    #knl = lp.tag_dimensions(knl, dict(ax_c="unr"))
+    knl = lp.tag_dimensions(knl, dict(ax_c="unr"))
     seq_knl = knl
     #print lp.preprocess_kernel(seq_knl)
     #1/0
@@ -97,8 +97,6 @@ def test_laplacian_stiffness_nd(ctx_factory):
     ctx = ctx_factory()
     order = "C"
 
-    1/0 # still encounters a dependency bug
-
     dim = 2
 
     Nq = 40 # num. quadrature points
@@ -111,14 +109,14 @@ def test_laplacian_stiffness_nd(ctx_factory):
     Nc_sym = var("Nc")
 
     knl = lp.make_kernel(ctx.devices[0],
-            "[Nc] -> {[K,i,j,q, ax_a, ax_b, ax_c]: 0<=K<Nc and 0<=i,j<%(Nb)d and 0<=q<%(Nq)d "
-            "and 0<= ax_a, ax_b, ax_c < %(dim)d}"
+            "[Nc] -> {[K,i,j,q, ax_a, ax_b]: 0<=K<Nc and 0<=i,j<%(Nb)d and 0<=q<%(Nq)d "
+            "and 0<= ax_a, ax_b < %(dim)d}"
             % dict(Nb=Nb, Nq=Nq, dim=dim),
             [
-                "dPsi(a, dxi) := sum_float32(@ax_c,"
-                    "  jacInv[ax_c,dxi,K,q] * DPsi[ax_c,a,q])",
+                "dPsi(a, dxi) := sum_float32(@ax_b,"
+                    "  jacInv[ax_b,dxi,K,q] * DPsi[ax_b,a,q])",
                 "A[K, i, j] = sum_float32(q, w[q] * jacDet[K,q] * ("
-                    "sum_float32(ax_b, dPsi(0,ax_b)*dPsi(1,ax_b))))"
+                    "sum_float32(ax_a, dPsi(0,ax_a)*dPsi(1,ax_a))))"
                 ],
             [
             lp.ArrayArg("jacInv", dtype, shape=(dim, dim, Nc_sym, Nq), order=order),
@@ -130,7 +128,7 @@ def test_laplacian_stiffness_nd(ctx_factory):
             ],
             name="lapquad", assumptions="Nc>=1")
 
-    #knl = lp.tag_dimensions(knl, dict(ax_c="unr"))
+    knl = lp.tag_dimensions(knl, dict(ax_b="unr"))
     seq_knl = knl
 
     def variant_1(knl):
@@ -154,19 +152,15 @@ def test_laplacian_stiffness_nd(ctx_factory):
         # no ILP across elements, precompute dPsiTransf
         knl = lp.split_dimension(knl, "K", 16, outer_tag="g.0", slabs=(0,1))
         knl = lp.tag_dimensions(knl, {"i": "l.0", "j": "l.1"})
-        knl = lp.precompute(knl, "dPsi",
-                ["a", "dxi"])
+        knl = lp.precompute(knl, "dPsi", np.float32)
         knl = lp.add_prefetch(knl, "jacInv",
                 ["jacInv_dim_0", "jacInv_dim_1", "K_inner", "q"])
         return knl
 
     for variant in [variant_1, variant_2]:
     #for variant in [variant_3]:
-        kernel_gen = lp.generate_loop_schedules(seq_knl,#variant(knl),
-                loop_priority=["jacInv_dim_0", "jacInv_dim_1"],
-                debug=15)
-        for knl in kernel_gen:
-            pass
+        kernel_gen = lp.generate_loop_schedules(variant(knl),
+                loop_priority=["jacInv_dim_0", "jacInv_dim_1"])
         kernel_gen = lp.check_kernels(kernel_gen, dict(Nc=Nc))
 
         lp.auto_test_vs_seq(seq_knl, ctx, kernel_gen,
-- 
GitLab