From 95d35e85ce64e6dda020cea14710693ee23a7cc3 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 26 May 2016 02:12:39 +0200
Subject: [PATCH] Add un-exposed reduction_arg_to_subst_rule transformation

---
 loopy/transform/data.py | 65 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index 3db96712e..6038bb88e 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -571,4 +571,69 @@ def set_temporary_scope(kernel, temp_var_names, scope):
 
 # }}}
 
+
+# {{{ reduction_arg_to_subst_rule
+
+def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=None):
+    if isinstance(inames, str):
+        inames = [s.strip() for s in inames.split(",")]
+
+    inames_set = frozenset(inames)
+
+    substs = knl.substitutions.copy()
+
+    var_name_gen = knl.get_var_name_generator()
+
+    def map_reduction(expr, rec, nresults=1):
+        if frozenset(expr.inames) != inames_set:
+            return type(expr)(
+                    operation=expr.operation,
+                    inames=expr.inames,
+                    expr=rec(expr.expr),
+                    allow_simultaneous=expr.allow_simultaneous)
+
+        if subst_rule_name is None:
+            subst_rule_prefix = "red_%s_arg" % "_".join(inames)
+            my_subst_rule_name = var_name_gen(subst_rule_prefix)
+        else:
+            my_subst_rule_name = subst_rule_name
+
+        if my_subst_rule_name in substs:
+            raise LoopyError("substitution rule '%s' already exists"
+                    % my_subst_rule_name)
+
+        from loopy.kernel.data import SubstitutionRule
+        substs[my_subst_rule_name] = SubstitutionRule(
+                name=my_subst_rule_name,
+                arguments=tuple(inames),
+                expression=expr.expr)
+
+        from pymbolic import var
+        iname_vars = [var(iname) for iname in inames]
+
+        return type(expr)(
+                operation=expr.operation,
+                inames=expr.inames,
+                expr=var(my_subst_rule_name)(*iname_vars),
+                allow_simultaneous=expr.allow_simultaneous)
+
+    from loopy.symbolic import ReductionCallbackMapper
+    cb_mapper = ReductionCallbackMapper(map_reduction)
+
+    from loopy.kernel.data import MultiAssignmentBase
+
+    new_insns = []
+    for insn in knl.instructions:
+        if not isinstance(insn, MultiAssignmentBase):
+            new_insns.append(insn)
+        else:
+            new_insns.append(insn.copy(expression=cb_mapper(insn.expression)))
+
+    return knl.copy(
+            instructions=new_insns,
+            substitutions=substs)
+
+# }}}
+
+
 # vim: foldmethod=marker
-- 
GitLab