diff --git a/loopy/transform/arithmetic.py b/loopy/transform/arithmetic.py index eb4a1fc66a2d5e79a9bfbd202a9ec35c1d654006..fab133bd359153839a4b494c5cf35a7962930a95 100644 --- a/loopy/transform/arithmetic.py +++ b/loopy/transform/arithmetic.py @@ -152,28 +152,106 @@ def fold_constants(kernel): # {{{ collect_common_factors_on_increment # thus far undocumented -def collect_common_factors_on_increment(kernel, var_name, is_index_specific=False): +def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()): # FIXME: Does not understand subst rules for now if kernel.substitutions: from loopy.transform.subst import expand_subst kernel = expand_subst(kernel) + if var_name in kernel.temporary_variables: + var_descr = kernel.temporary_variables[var_name] + elif var_name in kernel.arg_dict: + var_descr = kernel.arg_dict[var_name] + else: + raise NameError("array '%s' was not found" % var_name) + + # {{{ check/normalize vary_by_axes + + if isinstance(vary_by_axes, str): + vary_by_axes = vary_by_axes.split(",") + + from loopy.kernel.array import ArrayBase + if isinstance(var_descr, ArrayBase): + if var_descr.dim_names is not None: + name_to_index = dict( + (name, idx) + for idx, name in enumerate(var_descr.dim_names)) + else: + name_to_index = {} + + def map_ax_name_to_index(ax): + if isinstance(ax, str): + try: + return name_to_index[ax] + except KeyError: + raise LoopyError("axis name '%s' not understood " % ax) + else: + return ax + + vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes] + + if ( + vary_by_axes + and + (min(vary_by_axes) < 0 + or + max(vary_by_axes) > var_descr.num_user_axes())): + raise LoopyError("vary_by_axes refers to out-of-bounds axis index") + + # }}} + + from pymbolic.mapper.substitutor import make_subst_func from pymbolic.primitives import (Sum, Product, is_zero, flattened_sum, flattened_product, Subscript, Variable) - from loopy.symbolic import get_dependencies, SubstitutionMapper + from loopy.symbolic import (get_dependencies, SubstitutionMapper, + UnidirectionalUnifier) - # {{{ find common factors + # {{{ common factor key list maintenance - # maps lhs indices (or None for is_index_specific) - common_factors = {} + # list of (index_key, common factors found) + common_factors = [] - from loopy.kernel.data import ExpressionInstruction + def find_unifiable_cf_index(index_key): + for i, (key, val) in enumerate(common_factors): + unif = UnidirectionalUnifier( + lhs_mapping_candidates=get_dependencies(key)) + + unif_result = unif(key, index_key) + + if unif_result: + assert len(unif_result) == 1 + return i, unif_result[0] + + return None, None + + def extract_index_key(access_expr): + if isinstance(access_expr, Variable): + return () + + elif isinstance(access_expr, Subscript): + index = access_expr.index_tuple + return tuple(index[ax] for ax in vary_by_axes) + else: + raise ValueError("unexpected type of access_expr") def is_assignee(insn): return any( lhs == var_name for lhs, sbscript in insn.assignees_and_indices()) + def iterate_as(cls, expr): + if isinstance(expr, cls): + for ch in expr.children: + yield ch + else: + yield expr + + # }}} + + # {{{ find common factors + + from loopy.kernel.data import ExpressionInstruction + for insn in kernel.instructions: if not is_assignee(insn): continue @@ -182,46 +260,70 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals raise LoopyError("'%s' modified by non-expression instruction" % var_name) - (_, index_key), = insn.assignees_and_indices() - - if not is_index_specific: - index_key = None - lhs = insn.assignee rhs = insn.expression if is_zero(rhs): continue - if isinstance(rhs, Sum): - sum_terms = rhs.children + index_key = extract_index_key(lhs) + cf_index, unif_result = find_unifiable_cf_index(index_key) + + if cf_index is None: + # {{{ doesn't exist yet + + assert unif_result is None + + my_common_factors = None + + for term in iterate_as(Sum, rhs): + if term == lhs: + continue + + for part in iterate_as(Product, term): + if var_name in get_dependencies(part): + raise LoopyError("unexpected dependency on '%s' " + "in RHS of instruction '%s'" + % (var_name, insn.id)) + + product_parts = set(iterate_as(Product, term)) + + if my_common_factors is None: + my_common_factors = product_parts + else: + my_common_factors = my_common_factors & product_parts + + if my_common_factors is not None: + common_factors.append((index_key, my_common_factors)) + + # }}} else: - sum_terms = [rhs] + # {{{ match, filter existing common factors - my_common_factors = common_factors.get(index_key) + _, my_common_factors = common_factors[cf_index] - for term in sum_terms: - if term == lhs: - continue + unif_subst_map = SubstitutionMapper( + make_subst_func(unif_result.lmap)) - if isinstance(term, Product): - product_parts = set(term.children) - else: - product_parts = set([term]) + for term in iterate_as(Sum, rhs): + if term == lhs: + continue - for part in product_parts: - if var_name in get_dependencies(part): - raise LoopyError("unexpected dependency on '%s' " - "in RHS of instruction '%s'" - % (var_name, insn.id)) + for part in iterate_as(Product, term): + if var_name in get_dependencies(part): + raise LoopyError("unexpected dependency on '%s' " + "in RHS of instruction '%s'" + % (var_name, insn.id)) - if my_common_factors is None: - my_common_factors = product_parts - else: - my_common_factors = my_common_factors & product_parts + product_parts = set(iterate_as(Product, term)) - if my_common_factors is not None: - common_factors[index_key] = my_common_factors + my_common_factors = set( + cf for cf in my_common_factors + if unif_subst_map(cf) in product_parts) + + common_factors[cf_index] = (index_key, my_common_factors) + + # }}} # }}} @@ -236,9 +338,6 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals (_, index_key), = insn.assignees_and_indices() - if not is_index_specific: - index_key = None - lhs = insn.assignee rhs = insn.expression @@ -246,29 +345,34 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals new_insns.append(insn) continue - if isinstance(rhs, Sum): - sum_terms = rhs.children - else: - sum_terms = [rhs] + index_key = extract_index_key(lhs) + cf_index, unif_result = find_unifiable_cf_index(index_key) + + if cf_index is None: + new_insns.append(insn) + continue + + _, my_common_factors = common_factors[cf_index] + + unif_subst_map = SubstitutionMapper( + make_subst_func(unif_result.lmap)) + + mapped_my_common_factors = set( + unif_subst_map(cf) + for cf in my_common_factors) - my_common_factors = common_factors.get(index_key) new_sum_terms = [] - for term in sum_terms: + for term in iterate_as(Sum, rhs): if term == lhs: new_sum_terms.append(term) continue - if isinstance(term, Product): - product_parts = term.children - else: - product_parts = [term] - new_sum_terms.append( flattened_product([ part - for part in product_parts - if part not in my_common_factors + for part in iterate_as(Product, term) + if part not in mapped_my_common_factors ])) new_insns.append( @@ -280,23 +384,27 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals def find_substitution(expr): if isinstance(expr, Subscript): - if is_index_specific: - index_key = expr.index - else: - index_key = None - v = expr.aggregate.name elif isinstance(expr, Variable): v = expr.name else: - return None + return expr if v != var_name: - return None + return expr + + index_key = extract_index_key(expr) + cf_index, unif_result = find_unifiable_cf_index(index_key) + + unif_subst_map = SubstitutionMapper( + make_subst_func(unif_result.lmap)) + + _, my_common_factors = common_factors[cf_index] - my_common_factors = common_factors.get(index_key) if my_common_factors is not None: - return flattened_product(list(my_common_factors) + [expr]) + return flattened_product( + [unif_subst_map(cf) for cf in my_common_factors] + + [expr]) else: return expr