Skip to content
Snippets Groups Projects
Commit 73b3ab13 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Working version of the distributive law transform

parent dfd763c2
No related branches found
No related tags found
No related merge requests found
......@@ -152,28 +152,106 @@ def fold_constants(kernel):
# {{{ collect_common_factors_on_increment
# thus far undocumented
def collect_common_factors_on_increment(kernel, var_name, is_index_specific=False):
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
# FIXME: Does not understand subst rules for now
if kernel.substitutions:
from loopy.transform.subst import expand_subst
kernel = expand_subst(kernel)
if var_name in kernel.temporary_variables:
var_descr = kernel.temporary_variables[var_name]
elif var_name in kernel.arg_dict:
var_descr = kernel.arg_dict[var_name]
else:
raise NameError("array '%s' was not found" % var_name)
# {{{ check/normalize vary_by_axes
if isinstance(vary_by_axes, str):
vary_by_axes = vary_by_axes.split(",")
from loopy.kernel.array import ArrayBase
if isinstance(var_descr, ArrayBase):
if var_descr.dim_names is not None:
name_to_index = dict(
(name, idx)
for idx, name in enumerate(var_descr.dim_names))
else:
name_to_index = {}
def map_ax_name_to_index(ax):
if isinstance(ax, str):
try:
return name_to_index[ax]
except KeyError:
raise LoopyError("axis name '%s' not understood " % ax)
else:
return ax
vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]
if (
vary_by_axes
and
(min(vary_by_axes) < 0
or
max(vary_by_axes) > var_descr.num_user_axes())):
raise LoopyError("vary_by_axes refers to out-of-bounds axis index")
# }}}
from pymbolic.mapper.substitutor import make_subst_func
from pymbolic.primitives import (Sum, Product, is_zero,
flattened_sum, flattened_product, Subscript, Variable)
from loopy.symbolic import get_dependencies, SubstitutionMapper
from loopy.symbolic import (get_dependencies, SubstitutionMapper,
UnidirectionalUnifier)
# {{{ find common factors
# {{{ common factor key list maintenance
# maps lhs indices (or None for is_index_specific)
common_factors = {}
# list of (index_key, common factors found)
common_factors = []
from loopy.kernel.data import ExpressionInstruction
def find_unifiable_cf_index(index_key):
for i, (key, val) in enumerate(common_factors):
unif = UnidirectionalUnifier(
lhs_mapping_candidates=get_dependencies(key))
unif_result = unif(key, index_key)
if unif_result:
assert len(unif_result) == 1
return i, unif_result[0]
return None, None
def extract_index_key(access_expr):
if isinstance(access_expr, Variable):
return ()
elif isinstance(access_expr, Subscript):
index = access_expr.index_tuple
return tuple(index[ax] for ax in vary_by_axes)
else:
raise ValueError("unexpected type of access_expr")
def is_assignee(insn):
return any(
lhs == var_name
for lhs, sbscript in insn.assignees_and_indices())
def iterate_as(cls, expr):
if isinstance(expr, cls):
for ch in expr.children:
yield ch
else:
yield expr
# }}}
# {{{ find common factors
from loopy.kernel.data import ExpressionInstruction
for insn in kernel.instructions:
if not is_assignee(insn):
continue
......@@ -182,46 +260,70 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
raise LoopyError("'%s' modified by non-expression instruction"
% var_name)
(_, index_key), = insn.assignees_and_indices()
if not is_index_specific:
index_key = None
lhs = insn.assignee
rhs = insn.expression
if is_zero(rhs):
continue
if isinstance(rhs, Sum):
sum_terms = rhs.children
index_key = extract_index_key(lhs)
cf_index, unif_result = find_unifiable_cf_index(index_key)
if cf_index is None:
# {{{ doesn't exist yet
assert unif_result is None
my_common_factors = None
for term in iterate_as(Sum, rhs):
if term == lhs:
continue
for part in iterate_as(Product, term):
if var_name in get_dependencies(part):
raise LoopyError("unexpected dependency on '%s' "
"in RHS of instruction '%s'"
% (var_name, insn.id))
product_parts = set(iterate_as(Product, term))
if my_common_factors is None:
my_common_factors = product_parts
else:
my_common_factors = my_common_factors & product_parts
if my_common_factors is not None:
common_factors.append((index_key, my_common_factors))
# }}}
else:
sum_terms = [rhs]
# {{{ match, filter existing common factors
my_common_factors = common_factors.get(index_key)
_, my_common_factors = common_factors[cf_index]
for term in sum_terms:
if term == lhs:
continue
unif_subst_map = SubstitutionMapper(
make_subst_func(unif_result.lmap))
if isinstance(term, Product):
product_parts = set(term.children)
else:
product_parts = set([term])
for term in iterate_as(Sum, rhs):
if term == lhs:
continue
for part in product_parts:
if var_name in get_dependencies(part):
raise LoopyError("unexpected dependency on '%s' "
"in RHS of instruction '%s'"
% (var_name, insn.id))
for part in iterate_as(Product, term):
if var_name in get_dependencies(part):
raise LoopyError("unexpected dependency on '%s' "
"in RHS of instruction '%s'"
% (var_name, insn.id))
if my_common_factors is None:
my_common_factors = product_parts
else:
my_common_factors = my_common_factors & product_parts
product_parts = set(iterate_as(Product, term))
if my_common_factors is not None:
common_factors[index_key] = my_common_factors
my_common_factors = set(
cf for cf in my_common_factors
if unif_subst_map(cf) in product_parts)
common_factors[cf_index] = (index_key, my_common_factors)
# }}}
# }}}
......@@ -236,9 +338,6 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
(_, index_key), = insn.assignees_and_indices()
if not is_index_specific:
index_key = None
lhs = insn.assignee
rhs = insn.expression
......@@ -246,29 +345,34 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
new_insns.append(insn)
continue
if isinstance(rhs, Sum):
sum_terms = rhs.children
else:
sum_terms = [rhs]
index_key = extract_index_key(lhs)
cf_index, unif_result = find_unifiable_cf_index(index_key)
if cf_index is None:
new_insns.append(insn)
continue
_, my_common_factors = common_factors[cf_index]
unif_subst_map = SubstitutionMapper(
make_subst_func(unif_result.lmap))
mapped_my_common_factors = set(
unif_subst_map(cf)
for cf in my_common_factors)
my_common_factors = common_factors.get(index_key)
new_sum_terms = []
for term in sum_terms:
for term in iterate_as(Sum, rhs):
if term == lhs:
new_sum_terms.append(term)
continue
if isinstance(term, Product):
product_parts = term.children
else:
product_parts = [term]
new_sum_terms.append(
flattened_product([
part
for part in product_parts
if part not in my_common_factors
for part in iterate_as(Product, term)
if part not in mapped_my_common_factors
]))
new_insns.append(
......@@ -280,23 +384,27 @@ def collect_common_factors_on_increment(kernel, var_name, is_index_specific=Fals
def find_substitution(expr):
if isinstance(expr, Subscript):
if is_index_specific:
index_key = expr.index
else:
index_key = None
v = expr.aggregate.name
elif isinstance(expr, Variable):
v = expr.name
else:
return None
return expr
if v != var_name:
return None
return expr
index_key = extract_index_key(expr)
cf_index, unif_result = find_unifiable_cf_index(index_key)
unif_subst_map = SubstitutionMapper(
make_subst_func(unif_result.lmap))
_, my_common_factors = common_factors[cf_index]
my_common_factors = common_factors.get(index_key)
if my_common_factors is not None:
return flattened_product(list(my_common_factors) + [expr])
return flattened_product(
[unif_subst_map(cf) for cf in my_common_factors]
+ [expr])
else:
return expr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment