# {{{ collect_common_factors_on_increment

def collect_common_factors_on_increment(kernel, var_name, is_index_specific=False):
    """Factor common multiplicative terms out of the increments of *var_name*
    (distributive law) and multiply them back in where *var_name* is read.

    Every expression instruction assigning to *var_name* has its right-hand
    side viewed as a sum of terms. The term equal to the left-hand side (the
    "increment" part) is left alone. Among all remaining terms, the factors
    shared by *every* term are determined, removed from those terms, and
    multiplied onto each read of *var_name* in the instructions that do not
    assign to it.

    :arg kernel: the kernel to transform. If it has substitution rules, they
        are expanded first, since this transform does not understand them.
    :arg var_name: name of the variable whose increments are to be rewritten.
    :arg is_index_specific: if *True*, common factors are collected
        separately for each distinct left-hand-side index tuple; otherwise
        a single set of common factors is shared by all writes to *var_name*.
    :returns: a copy of *kernel* with rewritten instructions.
    :raises LoopyError: if *var_name* is written by a non-expression
        instruction, or if a right-hand-side term other than the left-hand
        side itself depends on *var_name*.

    .. note::

        With *is_index_specific*, usage sites are matched to writes by
        equality of index tuples. A read using a differently-spelled index
        finds no factors and is left unchanged.
    """

    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import get_dependencies, SubstitutionMapper
    from loopy.kernel.data import ExpressionInstruction
    from loopy.diagnostic import LoopyError

    def is_assignee(insn):
        return any(
                lhs_name == var_name
                for lhs_name, _ in insn.assignees_and_indices())

    def get_index_key(insn):
        # NOTE(review): assumes assignees_and_indices() yields the index as
        # a tuple ('()' for scalars), which is what find_substitution below
        # mirrors via Subscript.index_tuple -- confirm against loopy.kernel.data.
        (_, index_tuple), = insn.assignees_and_indices()
        return index_tuple if is_index_specific else None

    def get_sum_terms(rhs):
        return rhs.children if isinstance(rhs, Sum) else (rhs,)

    def get_product_parts(term):
        return term.children if isinstance(term, Product) else (term,)

    # {{{ find common factors

    # maps lhs index tuples (or None if not is_index_specific) to the set of
    # factors shared by all non-lhs terms written under that key
    common_factors = {}

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, ExpressionInstruction):
            raise LoopyError("'%s' modified by non-expression instruction"
                    % var_name)

        index_key = get_index_key(insn)
        lhs = insn.assignee
        rhs = insn.expression

        # zero-initializations carry no factors and must not empty the
        # intersection
        if is_zero(rhs):
            continue

        my_common_factors = common_factors.get(index_key)

        for term in get_sum_terms(rhs):
            if term == lhs:
                continue

            product_parts = set(get_product_parts(term))

            for part in product_parts:
                if var_name in get_dependencies(part):
                    raise LoopyError("unexpected dependency on '%s' "
                            "in RHS of instruction '%s'"
                            % (var_name, insn.id))

            if my_common_factors is None:
                my_common_factors = product_parts
            else:
                my_common_factors = my_common_factors & product_parts

        if my_common_factors is not None:
            common_factors[index_key] = my_common_factors

    # }}}

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, ExpressionInstruction) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        index_key = get_index_key(insn)
        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        my_common_factors = common_factors.get(index_key, frozenset())
        new_sum_terms = []

        for term in get_sum_terms(rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            # Strip each common factor exactly once: the usage sites are
            # multiplied by each factor once, so a repeated factor
            # (e.g. a*a*b with common factor a) must keep its extra copies.
            not_yet_removed = set(my_common_factors)
            remaining_parts = []
            for part in get_product_parts(term):
                if part in not_yet_removed:
                    not_yet_removed.remove(part)
                else:
                    remaining_parts.append(part)

            new_sum_terms.append(flattened_product(remaining_parts))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        # Returning None tells the mapper to keep recursing into *expr*.
        if isinstance(expr, Subscript):
            if not isinstance(expr.aggregate, Variable):
                return None

            v = expr.aggregate.name
            index_tuple = expr.index_tuple
        elif isinstance(expr, Variable):
            v = expr.name
            index_tuple = ()
        else:
            return None

        if v != var_name:
            return None

        index_key = index_tuple if is_index_specific else None

        my_common_factors = common_factors.get(index_key)
        if not my_common_factors:
            return expr

        # sort for reproducible output; set iteration order is arbitrary
        return flattened_product(
                sorted(my_common_factors, key=str) + [expr])

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, ExpressionInstruction) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)

# }}}