diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py
index 4755ca1774a15480a2c6b255380dd724e47f9042..d1803e9985c233d7802fb6a1f3b2c963cc40188d 100644
--- a/loopy/transform/precompute.py
+++ b/loopy/transform/precompute.py
@@ -28,11 +28,13 @@ from six.moves import range, zip
 import islpy as isl
 from loopy.symbolic import (get_dependencies,
         RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
-        SubstitutionRuleMappingContext)
+        SubstitutionRuleMappingContext, WalkMapper, IdentityMapper)
 from loopy.diagnostic import LoopyError
 from pymbolic.mapper.substitutor import make_subst_func
 import numpy as np
 
+from pymbolic.primitives import Call, Variable
+
 from pymbolic import var
 
 from loopy.transform.array_buffer_map import (ArrayToBufferMap, NoOpArrayToBufferMap,
@@ -128,9 +130,9 @@ class RuleInvocationGatherer(RuleAwareIdentityMapper):
 
 # }}}
 
-
 # {{{ replace rule invocation
 
+
 class RuleInvocationReplacer(RuleAwareIdentityMapper):
     def __init__(self, rule_mapping_context,
             subst_name, subst_tag, within, access_descriptors, array_base_map,
@@ -253,6 +255,161 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
 
 # }}}
 
+# {{{ Mappers for figuring out all the substitutions whose inames need to be
+# changed
+
+
+class OtherSubstsInAnExpr(WalkMapper):
+    """Collect the names of the functions called with a tracked iname.
+
+    Every :class:`pymbolic.primitives.Call` that has one of the tracked inames
+    among its parameters contributes the name of the called function to
+    :attr:`other_substs`.
+
+    .. attribute:: other_substs
+
+        A :class:`set` of the names gathered since the last :meth:`reset`.
+
+    .. attribute:: iname_set
+
+        A :class:`frozenset` of the tracked inames, as variables.
+    """
+
+    other_substs = None
+    iname_set = None
+
+    def __init__(self, _iname):
+        """
+        :arg _iname: an iterable of the names of the inames to track.
+        """
+        self.other_substs = set()
+        self.iname_set = frozenset([var(iname) for iname in _iname])
+
+    def update_inames(self, _iname):
+        """Track the inames named in *_iname* instead of the current ones."""
+        self.iname_set = frozenset([var(iname) for iname in _iname])
+
+    def post_visit(self, expr):
+        if isinstance(expr, Call):
+            if self.iname_set & frozenset(expr.parameters):
+                self.other_substs.add(expr.function.name)
+
+    def reset(self):
+        """Forget the names gathered so far."""
+        # Rebind instead of clearing in place: a set that a caller still
+        # holds on to must not be emptied behind its back.
+        self.other_substs = set()
+
+# }}}
+
+# {{{ Mapper and helper functions to change the iname of a substitution
+
+
+class SubstitutionArgChanger(IdentityMapper):
+    """Replace each variable named *old_iname* by one named *new_iname*."""
+
+    new_iname = None
+    old_iname = None
+
+    def __init__(self, _old, _new):
+        """
+        :arg _old: the name of the variable to replace.
+        :arg _new: the name of the variable to put in its place.
+        """
+        self.new_iname = _new
+        self.old_iname = _old
+
+    def map_variable(self, expr):
+        if expr.name == self.old_iname:
+            return Variable(self.new_iname)
+        return expr
+
+
+def _rename_iname_in_subst(subst, old_iname, new_iname):
+    """Return a copy of the rule *subst* with *old_iname* renamed.
+
+    *new_iname* takes its place among the arguments and in the expression.
+    """
+    # If *old_iname* is not an argument, the arguments are kept as they are.
+    new_arg = tuple([new_iname if k == old_iname else k
+        for k in subst.arguments])
+    new_expr = SubstitutionArgChanger(old_iname, new_iname)(subst.expression)
+
+    return subst.copy(arguments=new_arg, expression=new_expr)
+
+
+def change_iname_for_subst(knl, subst_name, old_iname, new_iname):
+    """Return a renamed copy of the substitution rule named *subst_name*.
+
+    :arg knl: the kernel whose ``substitutions`` the rule is looked up in.
+        It is not modified.
+    :arg subst_name: the name of the rule.
+    :arg old_iname: the name to replace, both among the arguments of the rule
+        and in its expression.
+    :arg new_iname: the name to put in its place.
+    """
+    return _rename_iname_in_subst(
+            knl.substitutions[subst_name], old_iname, new_iname)
+
+# }}}
+
+# {{{ function for finding the recursive substitutions and instructions
+
+
+def recursive_substs_and_insns(kernel, new_compute_deps, new_subst_deps,
+        substs_in_insn, substs_in_subst, subst_deps, compute_deps):
+    """Find the substitutions and instructions reached from new dependencies.
+
+    :arg kernel: the kernel being transformed.
+    :arg new_compute_deps: the ids of the instructions to start from.
+    :arg new_subst_deps: the names of the substitutions to start from.
+    :arg substs_in_insn: a mapping from instruction id to the set of names of
+        the substitutions used in that instruction.
+    :arg substs_in_subst: a mapping from substitution name to the set of names
+        of the substitutions used in that substitution.
+    :arg subst_deps: the names of the substitutions found so far.
+    :arg compute_deps: the ids of the instructions found so far.
+    :returns: a tuple ``(newer_compute_deps, newer_subst_deps)`` of sets with
+        the instruction ids not yet in *compute_deps* and the substitution
+        names not yet in *subst_deps*. No argument is modified.
+    """
+    from loopy.kernel.tools import find_recursive_dependencies
+
+    # Substitutions used directly by what we start from
+    frontier = set()
+    for dep in new_compute_deps:
+        frontier.update(substs_in_insn[dep])
+    for subst in new_subst_deps:
+        frontier.update(substs_in_subst[subst])
+    frontier -= subst_deps
+
+    # Looking for substitutions that are present in these substitutions,
+    # until no new one turns up
+    newer_subst_deps = set()
+    while frontier:
+        newer_subst_deps.update(frontier)
+        nested = set()
+        for subst in frontier:
+            nested.update(substs_in_subst[subst])
+        frontier = nested - subst_deps - newer_subst_deps
+
+    # now we are looking for instructions which have these new substitutions
+    newer_compute_deps = set()
+    for insn in kernel.instructions:
+        if newer_subst_deps & substs_in_insn[insn.id]:
+            newer_compute_deps.add(insn.id)
+
+    # ... and for everything these instructions depend on
+    newer_compute_deps.update(find_recursive_dependencies(kernel,
+        newer_compute_deps-compute_deps))
+    newer_compute_deps -= compute_deps
+
+    return newer_compute_deps, newer_subst_deps
+
+# }}}
+
+# {{{ Precompute
+
 
 def precompute(kernel, subst_use, sweep_inames=[], within=None,
         storage_axes=None, temporary_name=None, precompute_inames=None,
@@ -335,6 +492,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     `` with the direct sweep axes being the
     slower-varying indices.
 
+    .. note:: This version guarantees that the logic of the program remains
+        unaffected even if there are dependencies between the substitutions,
+        but it does not guarantee that the arguments of the other
+        substitutions remain the same.
+
     Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
     eliminated.
     """
@@ -853,15 +1015,82 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     # {{{ propagate storage iname subst to dependencies of compute instructions
 
     from loopy.kernel.tools import find_recursive_dependencies
+    prior_storage_axis_names = frozenset(storage_axis_subst_dict)
     compute_deps = find_recursive_dependencies(
             kernel, frozenset([compute_insn_id]))
 
+    substs_gatherer = OtherSubstsInAnExpr(sweep_inames)
+
+    # The gatherer also picks up calls to functions that are not substitution
+    # rules; only names of actual substitution rules may be followed.
+    known_substs = frozenset(kernel.substitutions)
+
+    substs_in_subst = {}
+    substs_in_insn = {}
+
+    for subst, rule in kernel.substitutions.items():
+        substs_gatherer(rule.expression)
+        # '&' builds a new set, which resetting the gatherer cannot empty.
+        substs_in_subst[subst] = substs_gatherer.other_substs & known_substs
+        substs_gatherer.reset()
+
+    from loopy.kernel.instruction import (MultiAssignmentBase,
+            _DataObliviousInstruction)
+    for insn in kernel.instructions:
+        if isinstance(insn, MultiAssignmentBase):
+            substs_gatherer(insn.expression)
+            substs_in_insn[insn.id] = substs_gatherer.other_substs & known_substs
+            substs_gatherer.reset()
+        elif isinstance(insn, _DataObliviousInstruction):
+            substs_in_insn[insn.id] = set()
+        else:
+            raise NotImplementedError(
+                    "Not implemented for the type of %s" % insn)
+
+    # special treatment for compute_insn
+    if not isinstance(compute_insn.assignee, Variable):
+        substs_gatherer.update_inames([ind.name for ind in
+            compute_insn.assignee.index_tuple])
+    substs_gatherer(compute_insn.expression)
+    substs_in_insn[compute_insn_id] = substs_gatherer.other_substs & known_substs
+
+    subst_deps = set()
+
+    # {{{ looking for recursive substitutions and dependency instructions
+
+    new_compute_deps, new_subst_deps = recursive_substs_and_insns(kernel,
+            compute_deps, subst_deps, substs_in_insn, substs_in_subst,
+            subst_deps, compute_deps)
+    subst_deps.update(new_subst_deps)
+
+    # }}}
+
+    # repeat as long as we are finding new instructions whose iname is to be
+    # changed
+    while new_compute_deps:
+        compute_deps.update(new_compute_deps)
+        new_compute_deps, new_subst_deps = recursive_substs_and_insns(kernel,
+                new_compute_deps, new_subst_deps, substs_in_insn, substs_in_subst,
+                subst_deps, compute_deps)
+        subst_deps.update(new_subst_deps)
+
+    new_substitutions = kernel.substitutions.copy()
+    for dep in subst_deps:
+        rule = kernel.substitutions[dep]
+        for iname in rule.arguments:
+            new_iname = storage_axis_subst_dict.get(iname, var(iname)).name
+            if new_iname != iname:
+                # Rename on top of the renames already made in this rule, so
+                # that every argument's rename makes it into the result.
+                rule = _rename_iname_in_subst(rule, iname, new_iname)
+        new_substitutions[dep] = rule
+    kernel = kernel.copy(substitutions=new_substitutions)
+
     # FIXME: Need to verify that there are no outside dependencies
     # on compute_deps
 
-    prior_storage_axis_names = frozenset(storage_axis_subst_dict)
-
     new_insns = []
+
     for insn in kernel.instructions:
         if (insn.id in compute_deps
                 and insn.within_inames & prior_storage_axis_names):
@@ -1004,4 +1233,6 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     return kernel
 
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_transform.py b/test/test_transform.py
index 0e10db362f36b7fc258059c2ec7ed1a344b97212..73b70095a72791dea2eda45a23049efa2f94a2de 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -241,6 +241,50 @@ def test_alias_temporaries(ctx_factory):
             parameters=dict(n=30))
 
 
+def test_precompute_v2(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    a = np.ones((16, 16))
+    b = np.ones((16, 16))
+
+    knl1 = lp.make_kernel(
+        "{ [i, i_1, j]: 0<=i, i_1, j<16 }",
+        """
+        subst1(j) := sum(i, a[i, j])
+        subst2(j) := sum(i_1, 2*subst1(j)*b[i_1, j])
+        out[j] = subst2(j)
+        """
+        )
+
+    knl2 = lp.make_kernel(
+        "{ [i, i_1, j]: 0<=i, i_1, j<16 }",
+        """
+        subst1(j) := sum(i, a[i, j])
+        tmp[j] = 2*subst1(j)
+        subst2(j) := sum(i_1, tmp[j]*b[i_1, j])
+        out[j] = subst2(j)
+        """, [lp.TemporaryVariable(
+            "tmp",
+            dtype=np.float64,
+            shape=(16,),
+            scope=lp.auto), '...'])
+
+    pknl1 = lp.precompute(knl1, "subst2", "j")
+    new_j = pknl1.substitutions["subst1"].arguments[0]
+    pknl1 = lp.precompute(pknl1, "subst1", new_j)
+
+    pknl2 = lp.precompute(knl2, "subst2", "j")
+    new_j = pknl2.substitutions["subst1"].arguments[0]
+    pknl2 = lp.precompute(pknl2, "subst1", new_j)
+
+    evt, (out1, ) = pknl1(queue, a=a, b=b)
+    evt, (out2, ) = pknl2(queue, a=a, b=b)
+
+    assert np.linalg.norm(out1-16*32*np.ones(16)) < 1e-15
+    assert np.linalg.norm(out2-16*32*np.ones(16)) < 1e-15
+
+
 def test_vectorize(ctx_factory):
     ctx = ctx_factory()
 