From 7a368750cb61deffa03da870f5ec397b6373c56e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 29 Oct 2011 14:07:50 -0400
Subject: [PATCH] Give user control over whether reduction inames are
 duplicated.

---
 MEMO                |   3 ++
 loopy/kernel.py     | 103 +++++++++++++++++++++++++++++++++++++++++---
 loopy/preprocess.py |  81 ----------------------------------
 3 files changed, 101 insertions(+), 86 deletions(-)

diff --git a/MEMO b/MEMO
index c13497a09..148999e6a 100644
--- a/MEMO
+++ b/MEMO
@@ -89,6 +89,9 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- Give the user control over which reduction inames are
+  duplicated.
+
 - assert dependencies <= parent_inames in loopy/__init__.py
   -> Yes, this must be the case.
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index db1a326a7..9af083c2e 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -830,6 +830,28 @@ def find_var_base_indices_and_shape_from_inames(domain, inames):
 
 
 
+# {{{ count number of uses of each reduction iname
+
+def count_reduction_iname_uses(insn):
+    """Return a dict mapping each iname to the number of reductions in
+    *insn*'s expression (as visited by ReductionCallbackMapper) listing it.
+    """
+    from loopy.symbolic import ReductionCallbackMapper
+
+    reduction_iname_uses = {}
+
+    def count_inames(expr, rec):
+        rec(expr.expr)
+        for iname in expr.inames:
+            reduction_iname_uses[iname] = reduction_iname_uses.get(iname, 0) + 1
+
+    ReductionCallbackMapper(count_inames)(insn.expression)
+    return reduction_iname_uses
+
+# }}}
+
+
+
 def make_kernel(*args, **kwargs):
     """Second pass of kernel creation. Think about requests for iname duplication
     and temporary variable declaration received as part of string instructions.
@@ -844,20 +866,68 @@ def make_kernel(*args, **kwargs):
 
     newly_created_vars = set()
 
-    for insn in knl.instructions:
+    # {{{ reduction iname duplication helper function
+
+    def duplicate_reduction_inames(reduction_expr, rec):
+        """Reduction callback: rename the inames *insn* wants duplicated."""
+        # NOTE: `insn` is a free variable, bound by the loop below at call time.
+        duplicate_inames = [iname
+                for iname, tag in insn.duplicate_inames_and_tags]
+        child = rec(reduction_expr.expr)
+        new_red_inames = []
+        did_something = False
+
+        for iname in reduction_expr.inames:
+            if iname in duplicate_inames:
+                new_iname = knl.make_unique_var_name(iname, newly_created_vars)
+                old_insn_inames.append(iname)  # axis to duplicate later
+                new_insn_inames.append(new_iname)
+                newly_created_vars.add(new_iname)
+                new_red_inames.append(new_iname)
+                reduction_iname_uses[iname] -= 1
+                did_something = True
+            else:
+                new_red_inames.append(iname)
+
+        if did_something:
+            from loopy.symbolic import SubstitutionMapper
+            from pymbolic.mapper.substitutor import make_subst_func
+            from pymbolic import var
+            subst_dict = dict(
+                    (old_iname, var(new_iname))
+                    for old_iname, new_iname in zip(
+                        reduction_expr.inames, new_red_inames))
+            subst_map = SubstitutionMapper(make_subst_func(subst_dict))
+            # inames that were not renamed map to themselves here
+            child = subst_map(child)
 
+        from loopy.symbolic import Reduction
+        return Reduction(
+                operation=reduction_expr.operation,
+                inames=tuple(new_red_inames),
+                expr=child)
+
+    # }}}
+
+    for insn in knl.instructions:
         # {{{ iname duplication
 
         if insn.duplicate_inames_and_tags:
+            # {{{ duplicate non-reduction inames
+
+            reduction_iname_uses = count_reduction_iname_uses(insn)
+
             duplicate_inames = [iname
-                    for iname, tag in insn.duplicate_inames_and_tags]
-            new_iname_tags = [tag for iname, tag in insn.duplicate_inames_and_tags]
+                    for iname, tag in insn.duplicate_inames_and_tags
+                    if iname not in reduction_iname_uses]
+            new_iname_tags = [tag for iname, tag in insn.duplicate_inames_and_tags
+                    if iname not in reduction_iname_uses]
 
             new_inames = [
                     knl.make_unique_var_name(
                         iname,
                         extra_used_vars=
-                        newly_created_vars | set(new_temp_vars.iterkeys()))
+                        newly_created_vars)
                     for iname in duplicate_inames]
 
             for iname, tag in zip(new_inames, new_iname_tags):
@@ -874,10 +944,31 @@ def make_kernel(*args, **kwargs):
                     (old_iname, var(new_iname))
                     for old_iname, new_iname in zip(duplicate_inames, new_inames))
             subst_map = SubstitutionMapper(make_subst_func(old_to_new))
+            new_expression = subst_map(insn.expression)
+
+            # }}}
+
+            # {{{ duplicate reduction inames
+
+            if len(duplicate_inames) < len(insn.duplicate_inames_and_tags):
+                # there must've been requests to duplicate reduction inames
+                old_insn_inames = []
+                new_insn_inames = []
+
+                from loopy.symbolic import ReductionCallbackMapper
+                new_expression = (
+                        ReductionCallbackMapper(duplicate_reduction_inames)
+                        (new_expression))
+
+                from loopy.isl_helpers import duplicate_axes
+                for old, new in zip(old_insn_inames, new_insn_inames):
+                    new_domain = duplicate_axes(new_domain, [old], [new])
+
+            # }}}
 
             insn = insn.copy(
                     assignee=subst_map(insn.assignee),
-                    expression=subst_map(insn.expression),
+                    expression=new_expression,
                     forced_iname_deps=[
                         old_to_new.get(iname, iname) for iname in insn.forced_iname_deps],
                     )
@@ -919,6 +1010,8 @@ def make_kernel(*args, **kwargs):
                     base_indices=base_indices,
                     shape=shape)
 
+            newly_created_vars.add(assignee_name)
+
             insn = insn.copy(temp_var_type=None)
 
         # }}}
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 85ecca37d..6f02fd2af 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -6,86 +6,6 @@ import pyopencl.characterize as cl_char
 
 
 
-# {{{ make reduction variables unique
-
-def make_reduction_variables_unique(kernel):
-    # {{{ count number of uses of each reduction iname
-
-    def count_reduction_iname_uses(expr, rec):
-        rec(expr.expr)
-        for iname in expr.inames:
-            reduction_iname_uses[iname] = (
-                    reduction_iname_uses.get(iname, 0)
-                    + 1)
-
-    from loopy.symbolic import ReductionCallbackMapper
-    cb_mapper = ReductionCallbackMapper(count_reduction_iname_uses)
-
-    reduction_iname_uses = {}
-
-    for insn in kernel.instructions:
-        cb_mapper(insn.expression)
-
-    # }}}
-
-    # {{{ make iname uses in reduction unique
-
-    def ensure_reduction_iname_uniqueness(expr, rec):
-        child = rec(expr.expr)
-        my_created_inames = []
-        new_red_inames = []
-
-        for iname in expr.inames:
-            if reduction_iname_uses[iname] > 1:
-                new_iname = kernel.make_unique_var_name(iname, set(new_inames))
-
-                old_inames.append(iname)
-                new_inames.append(new_iname)
-                my_created_inames.append(new_iname)
-                new_red_inames.append(new_iname)
-                reduction_iname_uses[iname] -= 1
-            else:
-                new_red_inames.append(iname)
-
-        if my_created_inames:
-            from loopy.symbolic import SubstitutionMapper
-            from pymbolic.mapper.substitutor import make_subst_func
-            from pymbolic import var
-            subst_dict = dict(
-                    (old_iname, var(new_iname))
-                    for old_iname, new_iname in zip(expr.inames, my_created_inames))
-            subst_map = SubstitutionMapper(make_subst_func(subst_dict))
-
-            child = subst_map(child)
-
-        from loopy.symbolic import Reduction
-        return Reduction(
-                operation=expr.operation,
-                inames=tuple(new_red_inames),
-                expr=child)
-
-    new_insns = []
-    old_inames = []
-    new_inames = []
-
-    from loopy.symbolic import ReductionCallbackMapper
-    cb_mapper = ReductionCallbackMapper(ensure_reduction_iname_uniqueness)
-
-    new_insns = [
-        insn.copy(expression=cb_mapper(insn.expression))
-        for insn in kernel.instructions]
-
-    domain = kernel.domain
-    from loopy.isl_helpers import duplicate_axes
-    for old, new in zip(old_inames, new_inames):
-        domain = duplicate_axes(domain, [old], [new])
-
-    return kernel.copy(instructions=new_insns, domain=domain)
-
-    # }}}
-
-# }}}
-
 # {{{ rewrite reduction to imperative form
 
 def realize_reduction(kernel):
@@ -545,7 +465,6 @@ def adjust_local_temp_var_storage(kernel):
 
 
 def preprocess_kernel(kernel):
-    kernel = make_reduction_variables_unique(kernel)
     kernel = realize_reduction(kernel)
 
     # {{{ check that all CSEs have been realized
-- 
GitLab