diff --git a/doc/reference.rst b/doc/reference.rst
index db6baee8ab16521027eb3d28cb199d679bb90a98..595fd4b5fa5bac3ccd69cfb1b7a05ebb1c7b75f9 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -315,6 +315,8 @@ C Block Instructions
 
 .. autoclass:: CInstruction
 
+.. _substitution-rule:
+
 Substitution Rules
 ^^^^^^^^^^^^^^^^^^
 
@@ -389,6 +391,8 @@ Dealing with Substitution Rules
 
 .. autofunction:: extract_subst
 
+.. autofunction:: temporary_to_subst
+
 .. autofunction:: expand_subst
 
 Caching, Precomputation and Prefetching
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 6e90433df76ed30183ec15632584b032adee31dc..97cdef6f6e651b9d725b108966a37d16e1f01992 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -53,7 +53,7 @@ from loopy.kernel.tools import (
         add_and_infer_dtypes)
 from loopy.kernel.creation import make_kernel, UniqueName
 from loopy.library.reduction import register_reduction_parser
-from loopy.subst import extract_subst, expand_subst
+from loopy.subst import extract_subst, expand_subst, temporary_to_subst
 from loopy.precompute import precompute
 from loopy.padding import (split_arg_axis, find_padding_multiple,
         add_padding)
@@ -81,7 +81,7 @@ __all__ = [
 
         "register_reduction_parser",
 
-        "extract_subst", "expand_subst",
+        "extract_subst", "expand_subst", "temporary_to_subst",
         "precompute",
         "split_arg_axis", "find_padding_multiple", "add_padding",
 
diff --git a/loopy/subst.py b/loopy/subst.py
index afafe3595de72f18b13997d903baf578e4ac9ed5..a9dca6653b1c24cb3079a48b23d4b08422bda83b 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -25,7 +25,10 @@ THE SOFTWARE.
 """
 
 
-from loopy.symbolic import get_dependencies, SubstitutionMapper
+from loopy.symbolic import (
+        get_dependencies, SubstitutionMapper,
+        ExpandingIdentityMapper)
+from loopy.diagnostic import LoopyError
 from pymbolic.mapper.substitutor import make_subst_func
 
 from pytools import Record
@@ -189,6 +192,255 @@ def extract_subst(kernel, subst_name, template, parameters):
             substitutions=new_substs)
 
 
+# {{{ temporary_to_subst
+
+class TemporaryToSubstChanger(ExpandingIdentityMapper):
+    """Rewrite reads of a temporary into substitution rule invocations.
+
+    Each defining instruction of *temp_name* gets its own substitution
+    rule name, created the first time one of its usage sites is rewritten.
+
+    .. attribute:: definition_insn_id_to_subst_name
+
+        Maps the id of each defining instruction for which at least one
+        usage site was rewritten to the name of its substitution rule.
+
+    .. attribute:: saw_unmatched_usage_sites
+
+        Maps the id of each defining instruction to a flag that is *True*
+        if a usage site rejected by *within* was encountered.
+    """
+
+    def __init__(self, kernel, temp_name, definition_insn_ids,
+            usage_to_definition, within):
+        self.var_name_gen = kernel.get_var_name_generator()
+
+        super(TemporaryToSubstChanger, self).__init__(
+                kernel.substitutions, self.var_name_gen)
+
+        self.kernel = kernel
+        self.temp_name = temp_name
+        self.definition_insn_ids = definition_insn_ids
+        self.usage_to_definition = usage_to_definition
+
+        self.within = within
+
+        self.definition_insn_id_to_subst_name = {}
+
+        self.saw_unmatched_usage_sites = {}
+        for def_id in self.definition_insn_ids:
+            self.saw_unmatched_usage_sites[def_id] = False
+
+    def get_subst_name(self, def_insn_id):
+        """Return the substitution rule name for *def_insn_id*, generating
+        a fresh unique name on first request.
+        """
+        try:
+            return self.definition_insn_id_to_subst_name[def_insn_id]
+        except KeyError:
+            subst_name = self.var_name_gen(self.temp_name+"_subst")
+            self.definition_insn_id_to_subst_name[def_insn_id] = subst_name
+            return subst_name
+
+    def map_variable(self, expr, expn_state):
+        if expr.name == self.temp_name:
+            result = self.transform_access(None, expn_state)
+            if result is not None:
+                return result
+
+        # Defer to the direct base class so that its handling of variables
+        # is not bypassed.
+        return super(TemporaryToSubstChanger, self).map_variable(
+                expr, expn_state)
+
+    def map_subscript(self, expr, expn_state):
+        if expr.aggregate.name == self.temp_name:
+            result = self.transform_access(expr.index, expn_state)
+            if result is not None:
+                return result
+
+        # *expr* is a subscript, so it must go through map_subscript
+        # (not map_variable) of the direct base class.
+        return super(TemporaryToSubstChanger, self).map_subscript(
+                expr, expn_state)
+
+    def transform_access(self, index, expn_state):
+        """Return the substitution rule invocation that replaces an access
+        to the temporary, or *None* if the access is to be left unchanged.
+
+        :arg index: *None* for an unsubscripted access, otherwise the
+            index of the subscript (a single expression or a tuple).
+        """
+        my_insn_id = expn_state.stack[0][0]
+
+        if my_insn_id in self.definition_insn_ids:
+            # Accesses within a defining instruction are left alone.
+            return None
+
+        my_def_id = self.usage_to_definition[my_insn_id]
+
+        if not self.within(expn_state.stack):
+            # Record that this definition still has a non-rewritten user.
+            self.saw_unmatched_usage_sites[my_def_id] = True
+            return None
+
+        subst_name = self.get_subst_name(my_def_id)
+
+        from pymbolic import var
+        if index is None:
+            return var(subst_name)
+
+        # A subscript with a single index may carry a bare expression
+        # instead of a tuple; normalize so that unpacking works.
+        if not isinstance(index, tuple):
+            index = (index,)
+
+        return var(subst_name)(*index)
+
+
+def temporary_to_subst(kernel, temp_name, within=None):
+    """Extract an assignment to a temporary variable
+    as a :ref:`substitution-rule`. The temporary may
+    be an array, in which case the array indices
+    (which must be plain variables) become the
+    arguments of the substitution rule.
+
+    :arg within: a stack match as understood by
+        :func:`loopy.context_matching.parse_stack_match`.
+
+    This operation will change all usage sites
+    of *temp_name* matched by *within*. If there
+    are further usage sites of *temp_name*, then
+    the original assignment to *temp_name* as well
+    as the temporary variable is left in place.
+    """
+
+    # {{{ establish the relevant definition of temp_name for each usage site
+
+    dep_kernel = expand_subst(kernel)
+    from loopy.preprocess import add_default_dependencies
+    dep_kernel = add_default_dependencies(dep_kernel)
+
+    id_to_insn = dep_kernel.id_to_insn
+
+    def get_relevant_definition_insn_id(usage_insn_id):
+        insn = id_to_insn[usage_insn_id]
+
+        def_id = set()
+        for dep_id in insn.insn_deps:
+            dep_insn = id_to_insn[dep_id]
+            if temp_name in dep_insn.write_dependency_names():
+                if temp_name in dep_insn.read_dependency_names():
+                    raise LoopyError("instruction '%s' both reads *and* "
+                            "writes '%s'--cannot transcribe to substitution "
+                            "rule" % (dep_id, temp_name))
+
+                def_id.add(dep_id)
+            else:
+                # Dependency branches without a write yield None, which
+                # must not be counted as a definition.
+                rec_result = get_relevant_definition_insn_id(dep_id)
+                if rec_result is not None:
+                    def_id.add(rec_result)
+
+        if len(def_id) > 1:
+            raise LoopyError("more than one write to '%s' found in "
+                    "depdendencies of '%s'--definition cannot be resolved"
+                    % (temp_name, usage_insn_id))
+
+        if not def_id:
+            return None
+        else:
+            def_id, = def_id
+
+        return def_id
+
+    usage_to_definition = {}
+    definition_insn_ids = set()
+
+    # Scan the substitution-expanded kernel so that reads of temp_name
+    # hidden inside substitution rules are seen as well.
+    for insn in dep_kernel.instructions:
+        if temp_name not in insn.read_dependency_names():
+            continue
+
+        def_id = get_relevant_definition_insn_id(insn.id)
+        if def_id is None:
+            raise LoopyError("no write to '%s' found in dependency tree "
+                    "of '%s'--definition cannot be resolved"
+                    % (temp_name, insn.id))
+
+        usage_to_definition[insn.id] = def_id
+        definition_insn_ids.add(def_id)
+
+    # }}}
+
+    from loopy.context_matching import parse_stack_match
+    within = parse_stack_match(within)
+
+    tts = TemporaryToSubstChanger(kernel, temp_name, definition_insn_ids,
+            usage_to_definition, within)
+
+    kernel = tts.map_kernel(kernel)
+
+    from loopy.kernel.data import SubstitutionRule
+
+    # {{{ create new substitution rules
+
+    new_substs = kernel.substitutions.copy()
+    for def_id, subst_name in six.iteritems(tts.definition_insn_id_to_subst_name):
+        def_insn = id_to_insn[def_id]
+
+        (_, indices), = def_insn.assignees_and_indices()
+
+        arguments = []
+
+        from pymbolic.primitives import Variable
+        for i in indices:
+            if not isinstance(i, Variable):
+                raise LoopyError("In defining instruction '%s': "
+                        "asignee index '%s' is not a plain variable. "
+                        "Perhaps use loopy.affine_map_inames() "
+                        "to perform substitution." % (def_id, i))
+
+            arguments.append(i.name)
+
+        new_substs[subst_name] = SubstitutionRule(
+                name=subst_name,
+                arguments=tuple(arguments),
+                expression=def_insn.expression)
+
+    # }}}
+
+    # {{{ delete temporary variable if possible
+
+    new_temp_vars = kernel.temporary_variables
+    if not any(six.itervalues(tts.saw_unmatched_usage_sites)):
+        # All usage sites matched--they're now substitution rules.
+        # We can get rid of the variable.
+
+        new_temp_vars = new_temp_vars.copy()
+        del new_temp_vars[temp_name]
+
+    # }}}
+
+    import loopy as lp
+    kernel = lp.remove_instructions(
+            kernel,
+            set(
+                insn_id
+                for insn_id, still_used in six.iteritems(
+                    tts.saw_unmatched_usage_sites)
+                if not still_used))
+
+    return kernel.copy(
+            substitutions=new_substs,
+            temporary_variables=new_temp_vars,
+            )
+
+# }}}
+
+
 def expand_subst(kernel, ctx_match=None):
     logger.debug("%s: expand subst" % kernel.name)
 
diff --git a/test/test_fortran.py b/test/test_fortran.py
index d887cd40197f7b942f89b3f2df25d5e563458e26..bf6ffe140289c13e5c04df02b0c226f44832d41b 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -1,7 +1,4 @@
-from __future__ import division
-from __future__ import absolute_import
-import six
-from six.moves import range
+from __future__ import division, absolute_import
 
 __copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
 
@@ -45,6 +42,9 @@ __all__ = [
         ]
 
 
+pytest.importorskip("fparser")  # skip this module if fparser is missing
+
+
 def test_fill(ctx_factory):
     fortran_src = """
         subroutine fill(out, a, n)
@@ -120,7 +120,94 @@ def test_asterisk_in_shape(ctx_factory):
 
     knl(queue, inp=np.array([1, 2, 3.]), n=3)
 
-    #lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5))
+
+def test_temporary_to_subst(ctx_factory):
+    """Scalar temporary with one definition that feeds two instructions."""
+    fortran_src = """
+        subroutine fill(out, out2, inp, n)
+          implicit none
+
+          real*8 a, out(n), out2(n), inp(n)
+          integer n
+
+          do i = 1, n
+            a = inp(n)
+            out(i) = 5*a
+            out2(i) = 6*a
+          end do
+        end
+        """
+
+    from loopy.frontend.fortran import f2loopy
+    knl, = f2loopy(fortran_src)
+
+    ref_knl = knl
+
+    knl = lp.temporary_to_subst(knl, "a")
+
+    ctx = ctx_factory()
+    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
+
+
+def test_temporary_to_subst_two_defs(ctx_factory):
+    """Scalar temporary that is redefined between its two uses."""
+    fortran_src = """
+        subroutine fill(out, out2, inp, n)
+          implicit none
+
+          real*8 a, out(n), out2(n), inp(n)
+          integer n
+
+          do i = 1, n
+            a = inp(i)
+            out(i) = 5*a
+            a = 3*inp(n)
+            out2(i) = 6*a
+          end do
+        end
+        """
+
+    from loopy.frontend.fortran import f2loopy
+    knl, = f2loopy(fortran_src)
+
+    ref_knl = knl
+
+    knl = lp.temporary_to_subst(knl, "a")
+
+    ctx = ctx_factory()
+    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
+
+
+def test_temporary_to_subst_indices(ctx_factory):
+    """Array temporary written in one loop and read in another."""
+    fortran_src = """
+        subroutine fill(out, out2, inp, n)
+          implicit none
+
+          real*8 a(n), out(n), out2(n), inp(n)
+          integer n
+
+          do i = 1, n
+            a(i) = 6*inp(i)
+          enddo
+
+          do i = 1, n
+            out(i) = 5*a(i)
+          end do
+        end
+        """
+
+    from loopy.frontend.fortran import f2loopy
+    knl, = f2loopy(fortran_src)
+
+    knl = lp.fix_parameters(knl, n=5)
+
+    ref_knl = knl
+
+    knl = lp.temporary_to_subst(knl, "a")
+
+    ctx = ctx_factory()
+    lp.auto_test_vs_ref(ref_knl, ctx, knl)
 
 
 if __name__ == "__main__":