From 971f37599809c30e432c74cf5d3347b96c5bacd0 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 29 Feb 2016 01:34:44 -0600
Subject: [PATCH] Add make_reduction_inames_unique

---
 loopy/__init__.py        |   4 +-
 loopy/preprocess.py      |   1 +
 loopy/transform/iname.py | 106 +++++++++++++++++++++++++++++++++++++++
 test/test_loopy.py       |  26 +++++++++-
 4 files changed, 134 insertions(+), 3 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index c71a03fec..5e3ad5085 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -60,7 +60,8 @@ from loopy.transform.iname import (
         split_iname, chunk_iname, join_inames, tag_inames, duplicate_inames,
         rename_iname, link_inames, remove_unused_inames,
         split_reduction_inward, split_reduction_outward,
-        affine_map_inames, find_unused_axis_tag)
+        affine_map_inames, find_unused_axis_tag,
+        make_reduction_inames_unique)
 
 from loopy.transform.instruction import (
         find_instructions, map_instructions,
@@ -144,6 +145,7 @@ __all__ = [
         "rename_iname", "link_inames", "remove_unused_inames",
         "split_reduction_inward", "split_reduction_outward",
         "affine_map_inames", "find_unused_axis_tag",
+        "make_reduction_inames_unique",
 
         "add_prefetch", "change_arg_to_image", "tag_data_axes",
         "set_array_dim_names", "remove_unused_arguments",
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index b70b39092..4c75cfd25 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -97,6 +97,7 @@ def check_reduction_iname_uniqueness(kernel):
                     "(%d of them, to be precise.) "
                     "Since this usage can easily cause loop scheduling "
                     "problems, this is prohibited by default. "
+                    "Use loopy.make_reduction_inames_unique() to fix this. "
                     "If you are sure that this is OK, write the reduction "
                     "as 'simul_reduce(...)' instead of 'reduce(...)'"
                     % (iname, count))
diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py
index 9c882b98d..b42b338a6 100644
--- a/loopy/transform/iname.py
+++ b/loopy/transform/iname.py
@@ -66,6 +66,8 @@ __doc__ = """
 
 .. autofunction:: find_unused_axis_tag
 
+.. autofunction:: make_reduction_inames_unique
+
 """
 
 
@@ -1405,4 +1407,108 @@ def separate_loop_head_tail_slab(kernel, iname, head_it_count, tail_it_count):
 
 # }}}
 
+
+# {{{ make_reduction_inames_unique
+
+class _ReductionInameUniquifier(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, inames, within):
+        super(_ReductionInameUniquifier, self).__init__(rule_mapping_context)
+
+        self.inames = inames
+        self.old_to_new = []
+        self.within = within
+
+        self.iname_to_red_count = {}
+        self.iname_to_nonsimultaneous_red_count = {}
+
+    def map_reduction(self, expr, expn_state):
+        within = self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)
+
+        for iname in expr.inames:
+            self.iname_to_red_count[iname] = (
+                    self.iname_to_red_count.get(iname, 0) + 1)
+            if not expr.allow_simultaneous:
+                self.iname_to_nonsimultaneous_red_count[iname] = (
+                    self.iname_to_nonsimultaneous_red_count.get(iname, 0) + 1)
+
+        if within and not expr.allow_simultaneous:
+            subst_dict = {}
+
+            from pymbolic import var
+
+            new_inames = []
+            for iname in expr.inames:
+                if (
+                        not (self.inames is None or iname in self.inames)
+                        or
+                        self.iname_to_red_count[iname] <= 1):
+                    new_inames.append(iname)
+                    continue
+
+                new_iname = self.rule_mapping_context.make_unique_var_name(iname)
+                subst_dict[iname] = var(new_iname)
+                self.old_to_new.append((iname, new_iname))
+                new_inames.append(new_iname)
+
+            from loopy.symbolic import SubstitutionMapper
+            from pymbolic.mapper.substitutor import make_subst_func
+
+            from loopy.symbolic import Reduction
+            return Reduction(expr.operation, tuple(new_inames),
+                    self.rec(
+                        SubstitutionMapper(make_subst_func(subst_dict))(
+                            expr.expr),
+                        expn_state),
+                    expr.allow_simultaneous)
+        else:
+            return super(_ReductionInameUniquifier, self).map_reduction(
+                    expr, expn_state)
+
+
+def make_reduction_inames_unique(kernel, inames=None, within=None):
+    """
+    :arg inames: if not *None*, only apply to these inames
+    :arg within: a stack match as understood by
+        :func:`loopy.context_matching.parse_stack_match`.
+
+    .. versionadded:: 2016.2
+    """
+
+    name_gen = kernel.get_var_name_generator()
+
+    from loopy.context_matching import parse_stack_match
+    within = parse_stack_match(within)
+
+    # {{{ change kernel
+
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, name_gen)
+    r_uniq = _ReductionInameUniquifier(rule_mapping_context,
+            inames, within=within)
+
+    kernel = rule_mapping_context.finish_kernel(
+            r_uniq.map_kernel(kernel))
+
+    # }}}
+
+    # {{{ duplicate the inames
+
+    for old_iname, new_iname in r_uniq.old_to_new:
+        from loopy.kernel.tools import DomainChanger
+        domch = DomainChanger(kernel, frozenset([old_iname]))
+
+        from loopy.isl_helpers import duplicate_axes
+        kernel = kernel.copy(
+                domains=domch.get_domains_with(
+                    duplicate_axes(domch.domain, [old_iname], [new_iname])))
+
+    # }}}
+
+    return kernel
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index b38bda855..606eec766 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -952,9 +952,31 @@ def test_double_sum(ctx_factory):
                 ],
             assumptions="n>=1")
 
-    cknl = lp.CompiledKernel(ctx, knl)
+    evt, (a, b) = knl(queue, n=n)
+
+    ref = sum(i*j for i in range(n) for j in range(n))
+    assert a.get() == ref
+    assert b.get() == ref
+
+
+def test_double_sum_made_unique(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    n = 20
+
+    knl = lp.make_kernel(
+            "{[i,j]: 0<=i,j<n }",
+            [
+                "a = sum((i,j), i*j)",
+                "b = sum(i, sum(j, i*j))",
+                ],
+            assumptions="n>=1")
+
+    knl = lp.make_reduction_inames_unique(knl)
+    print(knl)
 
-    evt, (a, b) = cknl(queue, n=n)
+    evt, (a, b) = knl(queue, n=n)
 
     ref = sum(i*j for i in range(n) for j in range(n))
     assert a.get() == ref
-- 
GitLab