From d654c56bf7fc2e68bab7c658e2de982e582f051f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sat, 5 Nov 2022 00:01:15 -0500 Subject: [PATCH] Implement pt.inline_calls --- pytato/__init__.py | 3 + pytato/codegen.py | 11 ++-- pytato/target/loopy/codegen.py | 70 ++++++++++++-------- pytato/transform/__init__.py | 8 +++ pytato/transform/calls.py | 114 +++++++++++++++++++++++++++++++++ 5 files changed, 177 insertions(+), 29 deletions(-) create mode 100644 pytato/transform/calls.py diff --git a/pytato/__init__.py b/pytato/__init__.py index 46794a9..0509a8c 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -93,6 +93,7 @@ from pytato.visualization import (get_dot_graph, show_dot_graph, get_dot_graph_from_partition, show_fancy_placeholder_data_flow, ) +from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls import pytato.analysis as analysis import pytato.tags as tags import pytato.function as function @@ -165,6 +166,8 @@ __all__ = ( "DistributedGraphPart", "DistributedGraphPartition", + "tag_all_calls_to_be_inlined", "inline_calls", + "find_distributed_partition", "number_distributed_tags", diff --git a/pytato/codegen.py b/pytato/codegen.py index d95f96a..0bc85d6 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -267,6 +267,7 @@ class PreprocessResult: def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: """Preprocess a computation for code generation.""" from pytato.transform import copy_dict_of_named_arrays + from pytato.transform.calls import inline_calls check_validity_of_outputs(outputs) @@ -294,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: # }}} - mapper = CodeGenPreprocessor(target) + new_outputs = inline_calls(outputs) + assert isinstance(new_outputs, DictOfNamedArrays) - new_outputs = copy_dict_of_named_arrays(outputs, mapper) + mapper = CodeGenPreprocessor(target) + new_outputs = copy_dict_of_named_arrays(new_outputs, mapper) return PreprocessResult(outputs=new_outputs, - compute_order=tuple(output_order), - bound_arguments=mapper.bound_arguments) + compute_order=tuple(output_order), + bound_arguments=mapper.bound_arguments) # vim: fdm=marker diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index ad7b734..7d8760e 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -46,6 +46,7 @@ from pytato.target.loopy import LoopyPyOpenCLTarget, LoopyTarget, ImplSubstituti from pytato.transform import Mapper from pytato.scalar_expr import ScalarExpression, INT_CLASSES from pytato.codegen import preprocess, normalize_outputs, SymbolicIndex +from pytato.function import Call, NamedCallResult from pytato.loopy import LoopyCall from pytato.tags import (ImplStored, ImplInlined, Named, PrefixNamed, ImplementationStrategy) @@ -575,6 +576,19 @@ class CodeGenMapper(Mapper): state.update_kernel(kernel) + def map_named_call_result(self, expr: NamedCallResult, + state: CodeGenState) -> None: + raise NotImplementedError("LoopyTarget does not support outlined calls" + " (yet). As a fallback, the call" + " could be inlined using" + " pt.mark_all_calls_to_be_inlined.") + + def map_call(self, expr: Call, state: CodeGenState) -> None: + raise NotImplementedError("LoopyTarget does not support outlined calls" + " (yet). As a fallback, the call" + " could be inlined using" + " pt.mark_all_calls_to_be_inlined.") + # }}} @@ -972,36 +986,30 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], .. note:: - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics: - - - Inames that index over an :class:`~pytato.array.Array`'s axis in the - allocation instruction are tagged with the corresponding - :class:`~pytato.array.Axis`'s tags. The caller may choose to not - propagate axis tags of type *axis_tag_t_to_not_propagate*. - - :attr:`pytato.Array.tags` of inputs/outputs in *outputs* - would be copied over to the tags of the corresponding - :class:`loopy.ArrayArg`. The caller may choose to not - propagate array tags of type *array_tag_t_to_not_propagate*. - - Arrays tagged with :class:`pytato.tags.ImplStored` would have their - tags copied over to the tags of corresponding - :class:`loopy.TemporaryVariable`. The caller may choose to not - propagate array tags of type *array_tag_t_to_not_propagate*. + - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics: + + - Inames that index over an :class:`~pytato.array.Array`'s axis in the + allocation instruction are tagged with the corresponding + :class:`~pytato.array.Axis`'s tags. The caller may choose to not + propagate axis tags of type *axis_tag_t_to_not_propagate*. + - :attr:`pytato.Array.tags` of inputs/outputs in *outputs* + would be copied over to the tags of the corresponding + :class:`loopy.ArrayArg`. The caller may choose to not + propagate array tags of type *array_tag_t_to_not_propagate*. + - Arrays tagged with :class:`pytato.tags.ImplStored` would have their + tags copied over to the tags of corresponding + :class:`loopy.TemporaryVariable`. The caller may choose to not + propagate array tags of type *array_tag_t_to_not_propagate*. + + .. warning:: + + Currently only :class:`~pytato.function.Call` nodes that are tagged with + :class:`pytato.tags.InlineCallTag` can be lowered to :mod:`loopy` IR. """ result_is_dict = isinstance(result, (dict, DictOfNamedArrays)) orig_outputs: DictOfNamedArrays = normalize_outputs(result) - # optimization: remove any ImplStored tags on outputs to avoid redundant - # store-load operations (see https://github.com/inducer/pytato/issues/415) - orig_outputs = DictOfNamedArrays( - {name: (output.without_tags(ImplStored(), - verify_existence=False) - if not isinstance(output, - InputArgumentBase) - else output) - for name, output in orig_outputs._data.items()}, - tags=orig_outputs.tags) - del result if cl_device is not None: @@ -1017,6 +1025,18 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], preproc_result = preprocess(orig_outputs, target) outputs = preproc_result.outputs + # optimization: remove any ImplStored tags on outputs to avoid redundant + # store-load operations (see https://github.com/inducer/pytato/issues/415) + # (This must be done after all the calls have been inlined) + outputs = DictOfNamedArrays( + {name: (output.without_tags(ImplStored(), + verify_existence=False) + if not isinstance(output, + InputArgumentBase) + else output) + for name, output in outputs._data.items()}, + tags=outputs.tags) + compute_order = preproc_result.compute_order if options is None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 1ceb4c4..953cbfa 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -96,6 +96,14 @@ Dict representation of DAGs .. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes + +Transforming call sites +----------------------- + +.. automodule:: pytato.transform.calls + +.. currentmodule:: pytato.transform + Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py new file mode 100644 index 0000000..a25dd83 --- /dev/null +++ b/pytato/transform/calls.py @@ -0,0 +1,114 @@ +""" +.. currentmodule:: pytato.transform.calls + +.. autofunction:: inline_calls +.. autofunction:: tag_all_calls_to_be_inlined +""" +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from immutables import Map +from pytato.transform import (ArrayOrNames, CopyMapper) +from pytato.array import (AbstractResultWithNamedArrays, Array, + DictOfNamedArrays, Placeholder) + +from pytato.function import Call, NamedCallResult +from pytato.tags import InlineCallTag + + +# {{{ inlining + +class PlaceholderSubstitutor(CopyMapper): + """ + .. attribute:: substitutions + + A mapping from the placeholder name to the array that it is to be + substituted with. + """ + def __init__(self, substitutions: Map[str, Array]) -> None: + super().__init__() + self.substitutions = substitutions + + def map_placeholder(self, expr: Placeholder) -> Array: + return self.substitutions[expr.name] + + +class Inliner(CopyMapper): + """ + Primary mapper for :func:`inline_calls`. + """ + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + # inline call sites within the callee. + new_expr = super().map_call(expr) + assert isinstance(new_expr, Call) + + if expr.tags_of_type(InlineCallTag): + substitutor = PlaceholderSubstitutor(expr.bindings) + + return DictOfNamedArrays( + {name: substitutor(ret) + for name, ret in new_expr.function.returns.items()}, + tags=expr.tags + ) + else: + return new_expr + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + new_call = self.rec(expr._container) + assert isinstance(new_call, AbstractResultWithNamedArrays) + return new_call[expr.name] + + +class InlineMarker(CopyMapper): + """ + Primary mapper for :func:`tag_all_calls_to_be_inlined`. + """ + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + return super().map_call(expr).tagged(InlineCallTag()) + + +def inline_calls(expr: ArrayOrNames) -> ArrayOrNames: + """ + Returns a copy of *expr* with call sites tagged with + :class:`pytato.tags.InlineCallTag` inlined into the expression graph. + """ + inliner = Inliner() + return inliner(expr) + + +def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames: + """ + Returns a copy of *expr* with all reachable instances of + :class:`pytato.function.Call` nodes tagged with + :class:`pytato.tags.InlineCallTag`. + + .. note:: + + This routine does NOT inline calls, to inline the calls + use :func:`tag_all_calls_to_be_inlined` on this routine's + output. + """ + return InlineMarker()(expr) + +# }}} + +# vim:foldmethod=marker -- GitLab