From 95d35e85ce64e6dda020cea14710693ee23a7cc3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 26 May 2016 02:12:39 +0200 Subject: [PATCH] Add un-exposed reduction_arg_to_subst_rule transformation --- loopy/transform/data.py | 65 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 3db96712e..6038bb88e 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -571,4 +571,69 @@ def set_temporary_scope(kernel, temp_var_names, scope): # }}} + +# {{{ reduction_arg_to_subst_rule + +def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=None): + if isinstance(inames, str): + inames = [s.strip() for s in inames.split(",")] + + inames_set = frozenset(inames) + + substs = knl.substitutions.copy() + + var_name_gen = knl.get_var_name_generator() + + def map_reduction(expr, rec, nresults=1): + if frozenset(expr.inames) != inames_set: + return type(expr)( + operation=expr.operation, + inames=expr.inames, + expr=rec(expr.expr), + allow_simultaneous=expr.allow_simultaneous) + + if subst_rule_name is None: + subst_rule_prefix = "red_%s_arg" % "_".join(inames) + my_subst_rule_name = var_name_gen(subst_rule_prefix) + else: + my_subst_rule_name = subst_rule_name + + if my_subst_rule_name in substs: + raise LoopyError("substitution rule '%s' already exists" + % my_subst_rule_name) + + from loopy.kernel.data import SubstitutionRule + substs[my_subst_rule_name] = SubstitutionRule( + name=my_subst_rule_name, + arguments=tuple(inames), + expression=expr.expr) + + from pymbolic import var + iname_vars = [var(iname) for iname in inames] + + return type(expr)( + operation=expr.operation, + inames=expr.inames, + expr=var(my_subst_rule_name)(*iname_vars), + allow_simultaneous=expr.allow_simultaneous) + + from loopy.symbolic import ReductionCallbackMapper + cb_mapper = ReductionCallbackMapper(map_reduction) + + from loopy.kernel.data import MultiAssignmentBase + + new_insns = [] + for insn in knl.instructions: + if not isinstance(insn, MultiAssignmentBase): + new_insns.append(insn) + else: + new_insns.append(insn.copy(expression=cb_mapper(insn.expression))) + + return knl.copy( + instructions=new_insns, + substitutions=substs) + +# }}} + + # vim: foldmethod=marker -- GitLab