diff --git a/doc/reference.rst b/doc/reference.rst index 556437ea44268045dca88482710e8f6e7b49476d..e04d0fa2e1a895d385a67da195d7544e97f2f69a 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -408,7 +408,7 @@ Dealing with Substitution Rules .. autofunction:: extract_subst -.. autofunction:: temporary_to_subst +.. autofunction:: assignment_to_subst .. autofunction:: expand_subst diff --git a/loopy/__init__.py b/loopy/__init__.py index 7850d67e9476732d6c53a5f31e57059c79a48ddb..e8dd4e54a6e3b7ba9701e300c027256a5524edbf 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -54,7 +54,7 @@ from loopy.kernel.tools import ( add_and_infer_dtypes) from loopy.kernel.creation import make_kernel, UniqueName from loopy.library.reduction import register_reduction_parser -from loopy.subst import extract_subst, expand_subst, temporary_to_subst +from loopy.subst import extract_subst, expand_subst, assignment_to_subst from loopy.precompute import precompute from loopy.buffer import buffer_array from loopy.fusion import fuse_kernels @@ -89,7 +89,7 @@ __all__ = [ "register_reduction_parser", - "extract_subst", "expand_subst", "temporary_to_subst", + "extract_subst", "expand_subst", "assignment_to_subst", "precompute", "buffer_array", "fuse_kernels", "split_arg_axis", "find_padding_multiple", "add_padding", diff --git a/loopy/subst.py b/loopy/subst.py index a0a031718962df3053b80058818b2f2a4b88d2c8..a29e950a1f32d660eb10147c8638612078e816aa 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -198,16 +198,16 @@ def extract_subst(kernel, subst_name, template, parameters=()): substitutions=new_substs) -# {{{ temporary_to_subst +# {{{ assignment_to_subst -class TemporaryToSubstChanger(RuleAwareIdentityMapper): - def __init__(self, rule_mapping_context, temp_name, definition_insn_ids, +class AssignmentToSubstChanger(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, lhs_name, definition_insn_ids, usage_to_definition, extra_arguments, within): self.var_name_gen = rule_mapping_context.make_unique_var_name - super(TemporaryToSubstChanger, self).__init__(rule_mapping_context) + super(AssignmentToSubstChanger, self).__init__(rule_mapping_context) - self.temp_name = temp_name + self.lhs_name = lhs_name self.definition_insn_ids = definition_insn_ids self.usage_to_definition = usage_to_definition @@ -226,28 +226,28 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper): try: return self.definition_insn_id_to_subst_name[def_insn_id] except KeyError: - subst_name = self.var_name_gen(self.temp_name+"_subst") + subst_name = self.var_name_gen(self.lhs_name+"_subst") self.definition_insn_id_to_subst_name[def_insn_id] = subst_name return subst_name def map_variable(self, expr, expn_state): - if (expr.name == self.temp_name + if (expr.name == self.lhs_name and expr.name not in expn_state.arg_context): result = self.transform_access(None, expn_state) if result is not None: return result - return super(TemporaryToSubstChanger, self).map_variable( + return super(AssignmentToSubstChanger, self).map_variable( expr, expn_state) def map_subscript(self, expr, expn_state): - if (expr.aggregate.name == self.temp_name + if (expr.aggregate.name == self.lhs_name and expr.aggregate.name not in expn_state.arg_context): result = self.transform_access(expr.index, expn_state) if result is not None: return result - return super(TemporaryToSubstChanger, self).map_subscript( + return super(AssignmentToSubstChanger, self).map_subscript( expr, expn_state) def transform_access(self, index, expn_state): @@ -280,26 +280,29 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper): return var(subst_name)(*index) -def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): - """Extract an assignment to a temporary variable +def assignment_to_subst(kernel, lhs_name, extra_arguments=(), within=None, + force_retain_argument=False): + """Extract an assignment (to a temporary variable or an argument) as a :ref:`substituiton-rule`. The temporary may be an array, in which case the array indices will become arguments to the substitution rule. :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. + :arg force_retain_argument: If True and if *lhs_name* is an argument, it is + kept even if it is no longer referenced. This operation will change all usage sites - of *temp_name* matched by *within*. If there - are further usage sites of *temp_name*, then - the original assignment to *temp_name* as well + of *lhs_name* matched by *within*. If there + are further usage sites of *lhs_name*, then + the original assignment to *lhs_name* as well as the temporary variable is left in place. """ if isinstance(extra_arguments, str): extra_arguments = tuple(s.strip() for s in extra_arguments.split(",")) - # {{{ establish the relevant definition of temp_name for each usage site + # {{{ establish the relevant definition of lhs_name for each usage site dep_kernel = expand_subst(kernel) from loopy.preprocess import add_default_dependencies @@ -313,11 +316,11 @@ def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): def_id = set() for dep_id in insn.insn_deps: dep_insn = id_to_insn[dep_id] - if temp_name in dep_insn.write_dependency_names(): - if temp_name in dep_insn.read_dependency_names(): + if lhs_name in dep_insn.write_dependency_names(): + if lhs_name in dep_insn.read_dependency_names(): raise LoopyError("instruction '%s' both reads *and* " "writes '%s'--cannot transcribe to substitution " - "rule" % (dep_id, temp_name)) + "rule" % (dep_id, lhs_name)) def_id.add(dep_id) else: @@ -329,7 +332,7 @@ def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): raise LoopyError("more than one write to '%s' found in " "depdendencies of '%s'--definition cannot be resolved " "(writer instructions ids: %s)" - % (temp_name, usage_insn_id, ", ".join(def_id))) + % (lhs_name, usage_insn_id, ", ".join(def_id))) if not def_id: return None @@ -341,20 +344,20 @@ def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): usage_to_definition = {} for insn in kernel.instructions: - if temp_name not in insn.read_dependency_names(): + if lhs_name not in insn.read_dependency_names(): continue def_id = get_relevant_definition_insn_id(insn.id) if def_id is None: raise LoopyError("no write to '%s' found in dependency tree " "of '%s'--definition cannot be resolved" - % (temp_name, insn.id)) + % (lhs_name, insn.id)) usage_to_definition[insn.id] = def_id definition_insn_ids = set() for insn in kernel.instructions: - if temp_name in insn.write_dependency_names(): + if lhs_name in insn.write_dependency_names(): definition_insn_ids.add(insn.id) # }}} @@ -364,8 +367,8 @@ def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): rule_mapping_context = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) - tts = TemporaryToSubstChanger(rule_mapping_context, - temp_name, definition_insn_ids, + tts = AssignmentToSubstChanger(rule_mapping_context, + lhs_name, definition_insn_ids, usage_to_definition, extra_arguments, within) kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel)) @@ -401,13 +404,28 @@ def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): # {{{ delete temporary variable if possible + # (copied below if modified) new_temp_vars = kernel.temporary_variables - if not any(six.itervalues(tts.saw_unmatched_usage_sites)): - # All usage sites matched--they're now substitution rules. - # We can get rid of the variable. + new_args = kernel.args - new_temp_vars = new_temp_vars.copy() - del new_temp_vars[temp_name] + if lhs_name in kernel.temporary_variables: + if not any(six.itervalues(tts.saw_unmatched_usage_sites)): + # All usage sites matched--they're now substitution rules. + # We can get rid of the variable. + + new_temp_vars = new_temp_vars.copy() + del new_temp_vars[lhs_name] + + if lhs_name in kernel.arg_dict and not force_retain_argument: + if not any(six.itervalues(tts.saw_unmatched_usage_sites)): + # All usage sites matched--they're now substitution rules. + # We can get rid of the argument + + new_args = new_args[:] + for i in range(len(new_args)): + if new_args[i].name == lhs_name: + del new_args[i] + break # }}} @@ -423,6 +441,7 @@ def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None): return kernel.copy( substitutions=new_substs, temporary_variables=new_temp_vars, + args=new_args, ) # }}} diff --git a/test/test_fortran.py b/test/test_fortran.py index c31c370076b681cb0593f38b6a4d92479541b872..212233ebcd895a93b5e89674323b339d98c08e21 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -123,7 +123,7 @@ def test_asterisk_in_shape(ctx_factory): knl(queue, inp=np.array([1, 2, 3.]), n=3) -def test_temporary_to_subst(ctx_factory): +def test_assignment_to_subst(ctx_factory): fortran_src = """ subroutine fill(out, out2, inp, n) implicit none @@ -143,13 +143,13 @@ def test_temporary_to_subst(ctx_factory): ref_knl = knl - knl = lp.temporary_to_subst(knl, "a", "i") + knl = lp.assignment_to_subst(knl, "a", "i") ctx = ctx_factory() lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5)) -def test_temporary_to_subst_two_defs(ctx_factory): +def test_assignment_to_subst_two_defs(ctx_factory): fortran_src = """ subroutine fill(out, out2, inp, n) implicit none @@ -170,13 +170,13 @@ def test_temporary_to_subst_two_defs(ctx_factory): ref_knl = knl - knl = lp.temporary_to_subst(knl, "a") + knl = lp.assignment_to_subst(knl, "a") ctx = ctx_factory() lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5)) -def test_temporary_to_subst_indices(ctx_factory): +def test_assignment_to_subst_indices(ctx_factory): fortran_src = """ subroutine fill(out, out2, inp, n) implicit none @@ -201,7 +201,7 @@ def test_temporary_to_subst_indices(ctx_factory): ref_knl = knl assert "a" in knl.temporary_variables - knl = lp.temporary_to_subst(knl, "a") + knl = lp.assignment_to_subst(knl, "a") assert "a" not in knl.temporary_variables ctx = ctx_factory() @@ -235,7 +235,7 @@ def test_if(ctx_factory): ref_knl = knl - knl = lp.temporary_to_subst(knl, "a") + knl = lp.assignment_to_subst(knl, "a") ctx = ctx_factory() lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5)) @@ -409,8 +409,8 @@ def test_fuse_kernels(ctx_factory): assert len(knl.temporary_variables) == 2 # This is needed for correctness, otherwise ordering could foul things up. - knl = lp.temporary_to_subst(knl, "prev") - knl = lp.temporary_to_subst(knl, "prev_0") + knl = lp.assignment_to_subst(knl, "prev") + knl = lp.assignment_to_subst(knl, "prev_0") ctx = ctx_factory() lp.auto_test_vs_ref(xyderiv, ctx, knl, parameters=dict(nelements=20, ndofs=4))