From d654c56bf7fc2e68bab7c658e2de982e582f051f Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 5 Nov 2022 00:01:15 -0500
Subject: [PATCH] Implement pt.inline_calls

---
 pytato/__init__.py             |   3 +
 pytato/codegen.py              |  11 ++--
 pytato/target/loopy/codegen.py |  70 ++++++++++++--------
 pytato/transform/__init__.py   |   8 +++
 pytato/transform/calls.py      | 114 +++++++++++++++++++++++++++++++++
 5 files changed, 177 insertions(+), 29 deletions(-)
 create mode 100644 pytato/transform/calls.py

diff --git a/pytato/__init__.py b/pytato/__init__.py
index 46794a9..0509a8c 100644
--- a/pytato/__init__.py
+++ b/pytato/__init__.py
@@ -93,6 +93,7 @@ from pytato.visualization import (get_dot_graph, show_dot_graph,
                                   get_dot_graph_from_partition,
                                   show_fancy_placeholder_data_flow,
                                   )
+from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls
 import pytato.analysis as analysis
 import pytato.tags as tags
 import pytato.function as function
@@ -165,6 +166,8 @@ __all__ = (
 
         "DistributedGraphPart",
         "DistributedGraphPartition",
+        "tag_all_calls_to_be_inlined", "inline_calls",
+
         "find_distributed_partition",
 
         "number_distributed_tags",
diff --git a/pytato/codegen.py b/pytato/codegen.py
index d95f96a..0bc85d6 100644
--- a/pytato/codegen.py
+++ b/pytato/codegen.py
@@ -267,6 +267,7 @@ class PreprocessResult:
 def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
     """Preprocess a computation for code generation."""
     from pytato.transform import copy_dict_of_named_arrays
+    from pytato.transform.calls import inline_calls
 
     check_validity_of_outputs(outputs)
 
@@ -294,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
 
     # }}}
 
-    mapper = CodeGenPreprocessor(target)
+    new_outputs = inline_calls(outputs)
+    assert isinstance(new_outputs, DictOfNamedArrays)
 
-    new_outputs = copy_dict_of_named_arrays(outputs, mapper)
+    mapper = CodeGenPreprocessor(target)
+    new_outputs = copy_dict_of_named_arrays(new_outputs, mapper)
 
     return PreprocessResult(outputs=new_outputs,
-            compute_order=tuple(output_order),
-            bound_arguments=mapper.bound_arguments)
+                            compute_order=tuple(output_order),
+                            bound_arguments=mapper.bound_arguments)
 
 # vim: fdm=marker
diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py
index ad7b734..7d8760e 100644
--- a/pytato/target/loopy/codegen.py
+++ b/pytato/target/loopy/codegen.py
@@ -46,6 +46,7 @@ from pytato.target.loopy import LoopyPyOpenCLTarget, LoopyTarget, ImplSubstituti
 from pytato.transform import Mapper
 from pytato.scalar_expr import ScalarExpression, INT_CLASSES
 from pytato.codegen import preprocess, normalize_outputs, SymbolicIndex
+from pytato.function import Call, NamedCallResult
 from pytato.loopy import LoopyCall
 from pytato.tags import (ImplStored, ImplInlined, Named, PrefixNamed,
                          ImplementationStrategy)
@@ -575,6 +576,19 @@ class CodeGenMapper(Mapper):
 
         state.update_kernel(kernel)
 
+    def map_named_call_result(self, expr: NamedCallResult,
+                              state: CodeGenState) -> None:
+        raise NotImplementedError("LoopyTarget does not support outlined calls"
+                                  " (yet). As a fallback, the call"
+                                  " could be inlined using"
+                                  " pt.mark_all_calls_to_be_inlined.")
+
+    def map_call(self, expr: Call, state: CodeGenState) -> None:
+        raise NotImplementedError("LoopyTarget does not support outlined calls"
+                                  " (yet). As a fallback, the call"
+                                  " could be inlined using"
+                                  " pt.mark_all_calls_to_be_inlined.")
+
 # }}}
 
 
@@ -972,36 +986,30 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]],
 
     .. note::
 
-        :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics:
-
-        - Inames that index over an :class:`~pytato.array.Array`'s axis in the
-          allocation instruction are tagged with the corresponding
-          :class:`~pytato.array.Axis`'s tags. The caller may choose to not
-          propagate axis tags of type *axis_tag_t_to_not_propagate*.
-        - :attr:`pytato.Array.tags` of inputs/outputs in *outputs*
-          would be copied over to the tags of the corresponding
-          :class:`loopy.ArrayArg`. The caller may choose to not
-          propagate array tags of type *array_tag_t_to_not_propagate*.
-        - Arrays tagged with :class:`pytato.tags.ImplStored` would have their
-          tags copied over to the tags of corresponding
-          :class:`loopy.TemporaryVariable`. The caller may choose to not
-          propagate array tags of type *array_tag_t_to_not_propagate*.
+        - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics:
+
+            - Inames that index over an :class:`~pytato.array.Array`'s axis in the
+              allocation instruction are tagged with the corresponding
+              :class:`~pytato.array.Axis`'s tags. The caller may choose to not
+              propagate axis tags of type *axis_tag_t_to_not_propagate*.
+            - :attr:`pytato.Array.tags` of inputs/outputs in *outputs*
+              would be copied over to the tags of the corresponding
+              :class:`loopy.ArrayArg`. The caller may choose to not
+              propagate array tags of type *array_tag_t_to_not_propagate*.
+            - Arrays tagged with :class:`pytato.tags.ImplStored` would have their
+              tags copied over to the tags of corresponding
+              :class:`loopy.TemporaryVariable`. The caller may choose to not
+              propagate array tags of type *array_tag_t_to_not_propagate*.
+
+    .. warning::
+
+        Currently only :class:`~pytato.function.Call` nodes that are tagged with
+        :class:`pytato.tags.InlineCallTag` can be lowered to :mod:`loopy` IR.
     """
 
     result_is_dict = isinstance(result, (dict, DictOfNamedArrays))
     orig_outputs: DictOfNamedArrays = normalize_outputs(result)
 
-    # optimization: remove any ImplStored tags on outputs to avoid redundant
-    # store-load operations (see https://github.com/inducer/pytato/issues/415)
-    orig_outputs = DictOfNamedArrays(
-        {name: (output.without_tags(ImplStored(),
-                                    verify_existence=False)
-                if not isinstance(output,
-                                  InputArgumentBase)
-                else output)
-         for name, output in orig_outputs._data.items()},
-        tags=orig_outputs.tags)
-
     del result
 
     if cl_device is not None:
@@ -1017,6 +1025,18 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]],
     preproc_result = preprocess(orig_outputs, target)
     outputs = preproc_result.outputs
 
+    # optimization: remove any ImplStored tags on outputs to avoid redundant
+    # store-load operations (see https://github.com/inducer/pytato/issues/415)
+    # (This must be done after all the calls have been inlined)
+    outputs = DictOfNamedArrays(
+        {name: (output.without_tags(ImplStored(),
+                                    verify_existence=False)
+                if not isinstance(output,
+                                  InputArgumentBase)
+                else output)
+         for name, output in outputs._data.items()},
+        tags=outputs.tags)
+
     compute_order = preproc_result.compute_order
 
     if options is None:
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 1ceb4c4..953cbfa 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -96,6 +96,14 @@ Dict representation of DAGs
 .. autofunction:: tag_user_nodes
 .. autofunction:: rec_get_user_nodes
 
+
+Transforming call sites
+-----------------------
+
+.. automodule:: pytato.transform.calls
+
+.. currentmodule:: pytato.transform
+
 Internal stuff that is only here because the documentation tool wants it
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py
new file mode 100644
index 0000000..a25dd83
--- /dev/null
+++ b/pytato/transform/calls.py
@@ -0,0 +1,114 @@
+"""
+.. currentmodule:: pytato.transform.calls
+
+.. autofunction:: inline_calls
+.. autofunction:: tag_all_calls_to_be_inlined
+"""
+__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from immutables import Map
+from pytato.transform import (ArrayOrNames, CopyMapper)
+from pytato.array import (AbstractResultWithNamedArrays, Array,
+                          DictOfNamedArrays, Placeholder)
+
+from pytato.function import Call, NamedCallResult
+from pytato.tags import InlineCallTag
+
+
+# {{{ inlining
+
+class PlaceholderSubstitutor(CopyMapper):
+    """
+    .. attribute:: substitutions
+
+        A mapping from the placeholder name to the array that it is to be
+        substituted with.
+    """
+    def __init__(self, substitutions: Map[str, Array]) -> None:
+        super().__init__()
+        self.substitutions = substitutions
+
+    def map_placeholder(self, expr: Placeholder) -> Array:
+        return self.substitutions[expr.name]
+
+
+class Inliner(CopyMapper):
+    """
+    Primary mapper for :func:`inline_calls`.
+    """
+    def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+        # inline call sites within the callee.
+        new_expr = super().map_call(expr)
+        assert isinstance(new_expr, Call)
+
+        if expr.tags_of_type(InlineCallTag):
+            substitutor = PlaceholderSubstitutor(expr.bindings)
+
+            return DictOfNamedArrays(
+                {name: substitutor(ret)
+                 for name, ret in new_expr.function.returns.items()},
+                tags=expr.tags
+            )
+        else:
+            return new_expr
+
+    def map_named_call_result(self, expr: NamedCallResult) -> Array:
+        new_call = self.rec(expr._container)
+        assert isinstance(new_call, AbstractResultWithNamedArrays)
+        return new_call[expr.name]
+
+
+class InlineMarker(CopyMapper):
+    """
+    Primary mapper for :func:`tag_all_calls_to_be_inlined`.
+    """
+    def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+        return super().map_call(expr).tagged(InlineCallTag())
+
+
+def inline_calls(expr: ArrayOrNames) -> ArrayOrNames:
+    """
+    Returns a copy of *expr* with call sites tagged with
+    :class:`pytato.tags.InlineCallTag` inlined into the expression graph.
+    """
+    inliner = Inliner()
+    return inliner(expr)
+
+
+def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames:
+    """
+    Returns a copy of *expr* with all reachable instances of
+    :class:`pytato.function.Call` nodes tagged with
+    :class:`pytato.tags.InlineCallTag`.
+
+    .. note::
+
+       This routine does NOT inline calls, to inline the calls
+       use :func:`tag_all_calls_to_be_inlined` on this routine's
+       output.
+    """
+    return InlineMarker()(expr)
+
+# }}}
+
+# vim:foldmethod=marker
-- 
GitLab