From b1f8dba26ecdd64162a1a08c59e77d222077360d Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Fri, 4 Nov 2022 23:44:17 -0500
Subject: [PATCH] Support mapper methods for FunctionDefinition, Call

Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>
---
 pytato/codegen.py            |  11 +-
 pytato/equality.py           |  27 ++++
 pytato/tags.py               |  28 +++-
 pytato/transform/__init__.py | 240 ++++++++++++++++++++++++++++++++++-
 pytato/visualization/dot.py  |  25 +++-
 5 files changed, 322 insertions(+), 9 deletions(-)

diff --git a/pytato/codegen.py b/pytato/codegen.py
index 63067fb..d95f96a 100644
--- a/pytato/codegen.py
+++ b/pytato/codegen.py
@@ -23,7 +23,7 @@ THE SOFTWARE.
 """
 
 import dataclasses
-from typing import Union, Dict, Tuple, List, Any
+from typing import Union, Dict, Tuple, List, Any, Optional
 
 from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
                           DataInterface, SizeParam, InputArgumentBase,
@@ -102,12 +102,25 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper):  # type: ignore[misc]
     ======================================  =====================================
     """
 
-    def __init__(self, target: Target) -> None:
+    def __init__(self, target: Target,
+                 kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None
+                 ) -> None:
         super().__init__()
         self.bound_arguments: Dict[str, DataInterface] = {}
         self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
         self.target = target
-        self.kernels_seen: Dict[str, lp.LoopKernel] = {}
+        # Test against None instead of truthiness: an empty dict handed over by
+        # clone_for_callee must stay shared, so that kernels recorded by the
+        # callee's preprocessor are also seen by the caller's.
+        self.kernels_seen: Dict[str, lp.LoopKernel] = (
+            {} if kernels_seen is None else kernels_seen)
+
+    def clone_for_callee(self) -> CodeGenPreprocessor:
+        """
+        Return a new :class:`CodeGenPreprocessor` for the same target that
+        shares :attr:`kernels_seen` with *self*.
+        """
+        return CodeGenPreprocessor(self.target, self.kernels_seen)
 
     def map_size_param(self, expr: SizeParam) -> Array:
         name = expr.name
diff --git a/pytato/equality.py b/pytato/equality.py
index 5eb1814..42c2978 100644
--- a/pytato/equality.py
+++ b/pytato/equality.py
@@ -31,6 +31,8 @@ from pytato.array import (AdvancedIndexInContiguousAxes,
                           IndexBase, IndexLambda, NamedArray,
                           Reshape, Roll, Stack, AbstractResultWithNamedArrays,
                           Array, DictOfNamedArrays, Placeholder, SizeParam)
+from pytato.function import Call, NamedCallResult, FunctionDefinition
+from pytools import memoize_method
 
 if TYPE_CHECKING:
     from pytato.loopy import LoopyCall, LoopyCallResult
@@ -273,6 +275,36 @@ class EqualityComparer:
                 and expr1.tags == expr2.tags
                 )
 
+    @memoize_method
+    def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
+                                ) -> bool:
+        # Called directly by map_call rather than through self.rec; memoized so
+        # that a pair of definitions is not re-compared for every call site.
+        return (expr1.__class__ is expr2.__class__
+                and expr1.parameters == expr2.parameters
+                and expr1.return_type == expr2.return_type
+                and (set(expr1.returns.keys()) == set(expr2.returns.keys()))
+                and all(self.rec(expr1.returns[k], expr2.returns[k])
+                        for k in expr1.returns)
+                and expr1.tags == expr2.tags
+                )
+
+    def map_call(self, expr1: Call, expr2: Any) -> bool:
+        return (expr1.__class__ is expr2.__class__
+                and self.map_function_definition(expr1.function, expr2.function)
+                and frozenset(expr1.bindings) == frozenset(expr2.bindings)
+                and all(self.rec(bnd,
+                                 expr2.bindings[name])
+                        for name, bnd in expr1.bindings.items())
+                # a call site carries its own tags, so compare them too
+                and expr1.tags == expr2.tags
+                )
+
+    def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool:
+        return (expr1.__class__ is expr2.__class__
+                and expr1.name == expr2.name
+                and self.rec(expr1._container, expr2._container))
+
 # }}}
 
 # vim: fdm=marker
diff --git a/pytato/tags.py b/pytato/tags.py
index ad6d33d..1e1180e 100644
--- a/pytato/tags.py
+++ b/pytato/tags.py
@@ -11,6 +11,8 @@ Pre-Defined Tags
 .. autoclass:: AssumeNonNegative
 .. autoclass:: ExpandedDimsReshape
 .. autoclass:: FunctionIdentifier
+.. autoclass:: CallImplementationTag
+.. autoclass:: InlineCallTag
 """
 
 from typing import Tuple, Hashable
@@ -131,8 +133,28 @@ class ExpandedDimsReshape(UniqueTag):
 @dataclass(frozen=True)
 class FunctionIdentifier(UniqueTag):
     """
-    A tag that can be attached to a
-    :class:`~pytato.function.FunctionDefinition` node to
-    to describe the function's identifier.
+    A tag that can be attached to a :class:`~pytato.function.FunctionDefinition`
+    node to describe the function's identifier. One can use this to refer to
+    all instances of :class:`~pytato.function.FunctionDefinition`, for example in
+    :func:`~pytato.transform.calls.concatenate_calls`.
+
+    .. attribute:: identifier
     """
     identifier: Hashable
+
+
+@dataclass(frozen=True)
+class CallImplementationTag(UniqueTag):
+    """
+    A tag that can be attached to a :class:`~pytato.function.Call` node to
+    direct a :class:`~pytato.target.Target` how the call site should be
+    lowered.
+    """
+
+
+@dataclass(frozen=True)
+class InlineCallTag(CallImplementationTag):
+    r"""
+    A :class:`CallImplementationTag` that directs the
+    :class:`pytato.target.Target` to inline the call site.
+    """
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index a75083f..1ceb4c4 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from pytools import memoize_method
+
 __copyright__ = """
 Copyright (C) 2020 Matt Wala
 Copyright (C) 2020-21 Kaushik Kulkarni
@@ -28,6 +30,7 @@ THE SOFTWARE.
 
 import logging
 import numpy as np
+from immutables import Map
 from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic,
                     List, Mapping, Iterable, Tuple, Optional,
                     Hashable)
@@ -43,10 +46,12 @@ from pytato.array import (
 from pytato.distributed.nodes import (
         DistributedSendRefHolder, DistributedRecv, DistributedSend)
 from pytato.loopy import LoopyCall, LoopyCallResult
+from pytato.function import Call, NamedCallResult, FunctionDefinition
 from dataclasses import dataclass
 from pytato.tags import ImplStored
 from pymbolic.mapper.optimize import optimize_mapper
 
+
 ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
 MappedT = TypeVar("MappedT",
                   Array, AbstractResultWithNamedArrays, ArrayOrNames)
@@ -56,6 +61,7 @@ CopyMapperResultT = TypeVar("CopyMapperResultT",  # used in CopyMapper
 CachedMapperT = TypeVar("CachedMapperT")  # used in CachedMapper
 IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
 R = FrozenSet[Array]
+_SelfMapper = TypeVar("_SelfMapper", bound="Mapper")
 
 __doc__ = """
 .. currentmodule:: pytato.transform
@@ -100,6 +106,11 @@ Internal stuff that is only here because the documentation tool wants it
 .. class:: CombineT
 
     A type variable representing the type of a :class:`CombineMapper`.
+
+.. class:: _SelfMapper
+
+    A type variable used to represent the type of a mapper in
+    :meth:`CopyMapper.clone_for_callee`.
 """
 
 transform_logger = logging.getLogger(__file__)
@@ -213,6 +224,8 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
     The typical use of this mapper is to override individual ``map_`` methods
     in subclasses to permit term rewriting on an expression graph.
 
+    .. automethod:: clone_for_callee
+
     .. note::
 
        This does not copy the data of a :class:`pytato.array.DataWrapper`.
@@ -229,6 +242,13 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
                  expr: CopyMapperResultT) -> CopyMapperResultT:
         return self.rec(expr)
 
+    def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
+        """
+        Called to clone *self* before starting traversal of a
+        :class:`pytato.function.FunctionDefinition`.
+        """
+        return type(self)()
+
     def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
                               ) -> Tuple[IndexOrShapeExpr, ...]:
         # type-ignore-reason: apparently mypy cannot substitute typevars
@@ -372,6 +392,32 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
                shape=self.rec_idx_or_size_tuple(expr.shape),
                dtype=expr.dtype, tags=expr.tags, axes=expr.axes)
 
+    @memoize_method
+    def map_function_definition(self,
+                                expr: FunctionDefinition) -> FunctionDefinition:
+        # spawn a new mapper to avoid unsound cache hits, since the namespace of the
+        # function's body is different from that of the caller.
+        new_mapper = self.clone_for_callee()
+        new_returns = {name: new_mapper(ret)
+                       for name, ret in expr.returns.items()}
+        return FunctionDefinition(expr.parameters,
+                                  expr.return_type,
+                                  Map(new_returns),
+                                  tags=expr.tags
+                                  )
+
+    def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+        return Call(self.map_function_definition(expr.function),
+                    Map({name: self.rec(bnd)
+                         for name, bnd in expr.bindings.items()}),
+                    tags=expr.tags,
+                    )
+
+    def map_named_call_result(self, expr: NamedCallResult) -> Array:
+        call = self.rec(expr._container)
+        assert isinstance(call, Call)
+        return NamedCallResult(call, expr.name)
+
 
 class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
     """
@@ -575,6 +621,26 @@ class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
                shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs),
                dtype=expr.dtype, tags=expr.tags, axes=expr.axes)
 
+    def map_function_definition(self, expr: FunctionDefinition,
+                                *args: Any, **kwargs: Any) -> FunctionDefinition:
+        raise NotImplementedError("Function definitions are purposefully left"
+                                  " unimplemented as the default arguments to a new"
+                                  " DAG traversal are tricky to guess.")
+
+    def map_call(self, expr: Call,
+                 *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays:
+        return Call(self.map_function_definition(expr.function, *args, **kwargs),
+                    Map({name: self.rec(bnd, *args, **kwargs)
+                         for name, bnd in expr.bindings.items()}),
+                    tags=expr.tags,
+                    )
+
+    def map_named_call_result(self, expr: NamedCallResult,
+                              *args: Any, **kwargs: Any) -> Array:
+        call = self.rec(expr._container, *args, **kwargs)
+        assert isinstance(call, Call)
+        return NamedCallResult(call, expr.name)
+
 # }}}
 
 
@@ -685,6 +751,18 @@ class CombineMapper(Mapper, Generic[CombineT]):
     def map_distributed_recv(self, expr: DistributedRecv) -> CombineT:
         return self.combine(*self.rec_idx_or_size_tuple(expr.shape))
 
+    def map_function_definition(self, expr: FunctionDefinition) -> CombineT:
+        raise NotImplementedError("Combining results from a callee expression"
+                                  " is context-dependent. Derived classes"
+                                  " must override map_function_definition.")
+
+    def map_call(self, expr: Call) -> CombineT:
+        return self.combine(self.map_function_definition(expr.function),
+                            *[self.rec(bnd) for bnd in expr.bindings.values()])
+
+    def map_named_call_result(self, expr: NamedCallResult) -> CombineT:
+        return self.rec(expr._container)
+
 # }}}
 
 
@@ -752,6 +830,18 @@ class DependencyMapper(CombineMapper[R]):
     def map_distributed_recv(self, expr: DistributedRecv) -> R:
         return self.combine(frozenset([expr]), super().map_distributed_recv(expr))
 
+    def map_function_definition(self, expr: FunctionDefinition) -> R:
+        # do not include arrays from the function's body as it would involve
+        # putting arrays from different namespaces into the same collection.
+        return frozenset()
+
+    def map_call(self, expr: Call) -> R:
+        return self.combine(self.map_function_definition(expr.function),
+                            *[self.rec(bnd) for bnd in expr.bindings.values()])
+
+    def map_named_call_result(self, expr: NamedCallResult) -> R:
+        return self.rec(expr._container)
+
 # }}}
 
 
@@ -796,6 +886,27 @@ class InputGatherer(CombineMapper[FrozenSet[InputArgumentBase]]):
     def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]:
         return frozenset([expr])
 
+    @memoize_method
+    def map_function_definition(self, expr: FunctionDefinition
+                                ) -> FrozenSet[InputArgumentBase]:
+        # get rid of placeholders local to the function.
+        new_mapper = InputGatherer()
+        all_callee_inputs = new_mapper.combine(*[new_mapper(ret)
+                                                 for ret in expr.returns.values()])
+        result: Set[InputArgumentBase] = set()
+        for inp in all_callee_inputs:
+            if isinstance(inp, Placeholder):
+                if inp.name in expr.parameters:
+                    # drop, reference to argument
+                    pass
+                else:
+                    raise ValueError("function definition refers to non-argument "
+                                     f"placeholder named '{inp.name}'")
+            else:
+                result.add(inp)
+
+        return frozenset(result)
+
 # }}}
 
 
@@ -814,6 +925,12 @@ class SizeParamGatherer(CombineMapper[FrozenSet[SizeParam]]):
     def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]:
         return frozenset([expr])
 
+    @memoize_method
+    def map_function_definition(self, expr: FunctionDefinition
+                                ) -> FrozenSet[SizeParam]:
+        return self.combine(*[self.rec(ret)
+                              for ret in expr.returns.values()])
+
 # }}}
 
 
@@ -830,6 +947,9 @@ class WalkMapper(Mapper):
     .. automethod:: post_visit
     """
 
+    def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
+        return type(self)()
+
     def visit(self, expr: Any, *args: Any, **kwargs: Any) -> bool:
         """
         If this method returns *True*, *expr* is traversed during the walk.
@@ -982,6 +1102,44 @@ class WalkMapper(Mapper):
 
         self.post_visit(expr, *args, **kwargs)
 
+    def map_function_definition(self, expr: FunctionDefinition,
+                                *args: Any, **kwargs: Any) -> None:
+        """
+        Walk the arrays returned by *expr* with a mapper obtained from
+        :meth:`clone_for_callee`, forwarding *args* and *kwargs*.
+        """
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        new_mapper = self.clone_for_callee()
+        for subexpr in expr.returns.values():
+            new_mapper(subexpr, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None:
+        """
+        Walk the callee of *expr* and each of its bindings, forwarding *args*
+        and *kwargs* to every hook and recursive call.
+        """
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        self.map_function_definition(expr.function, *args, **kwargs)
+        for bnd in expr.bindings.values():
+            self.rec(bnd, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_named_call_result(self, expr: NamedCallResult,
+                              *args: Any, **kwargs: Any) -> None:
+        if not self.visit(expr, *args, **kwargs):
+            return
+
+        self.rec(expr._container, *args, **kwargs)
+
+        self.post_visit(expr, *args, **kwargs)
+
 # }}}
 
 
@@ -1019,6 +1169,11 @@ class TopoSortMapper(CachedWalkMapper):
     """A mapper that creates a list of nodes in topological order.
 
     :members: topological_order
+
+    .. note::
+
+        Does not consider the nodes inside a
+        :class:`~pytato.function.FunctionDefinition`.
     """
 
     def __init__(self) -> None:
@@ -1033,6 +1188,12 @@ class TopoSortMapper(CachedWalkMapper):
     def post_visit(self, expr: Any) -> None:  # type: ignore[override]
         self.topological_order.append(expr)
 
+    # type-ignore-reason: dropped the extra `*args, **kwargs`.
+    def map_function_definition(self,  # type: ignore[override]
+                                expr: FunctionDefinition) -> None:
+        # do nothing as it includes arrays from a different namespace.
+        return
+
 # }}}
 
 
@@ -1048,6 +1209,11 @@ class CachedMapAndCopyMapper(CopyMapper):
         super().__init__()
         self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn
 
+    def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
+        # type-ignore-reason: self.__init__ has a different function signature
+        # than Mapper.__init__ and does not have map_fn
+        return type(self)(self.map_fn)  # type: ignore[call-arg,attr-defined]
+
     # type-ignore-reason:incompatible with Mapper.rec()
     def rec(self, expr: MappedT) -> MappedT:  # type: ignore[override]
         if expr in self._cache:
@@ -1102,7 +1268,14 @@ def _materialize_if_mpms(expr: Array,
 
 
 class MPMSMaterializer(Mapper):
-    """See :func:`materialize_with_mpms` for an explanation."""
+    """
+    See :func:`materialize_with_mpms` for an explanation.
+
+    .. attribute:: nsuccessors
+
+        A mapping from a node in the expression graph (i.e. an
+        :class:`~pytato.Array`) to its number of successors.
+    """
     def __init__(self, nsuccessors: Mapping[Array, int]):
         super().__init__()
         self.nsuccessors = nsuccessors
@@ -1252,6 +1425,57 @@ class MPMSMaterializer(Mapper):
                              ) -> MPMSMaterializerAccumulator:
         return MPMSMaterializerAccumulator(frozenset([expr]), expr)
 
+    @memoize_method
+    def map_function_definition(self, expr: FunctionDefinition
+                                ) -> FunctionDefinition:
+        # spawn a new traversal here.
+        from pytato.analysis import get_nusers
+
+        returns_dict_of_named_arys = DictOfNamedArrays(expr.returns)
+        func_nsuccessors = get_nusers(returns_dict_of_named_arys)
+        new_mapper = MPMSMaterializer(func_nsuccessors)
+        # the mapper yields accumulators; only the rewritten array goes back
+        # into the function definition.
+        new_returns = {name: new_mapper(ret).expr
+                       for name, ret in expr.returns.items()}
+        return FunctionDefinition(expr.parameters,
+                                  expr.return_type,
+                                  Map(new_returns),
+                                  tags=expr.tags)
+
+    @memoize_method
+    def map_call(self, expr: Call) -> Call:
+        return Call(self.map_function_definition(expr.function),
+                    Map({name: self.rec(bnd).expr
+                         for name, bnd in expr.bindings.items()}),
+                    tags=expr.tags)
+
+    def map_named_call_result(self, expr: NamedCallResult
+                              ) -> MPMSMaterializerAccumulator:
+        assert isinstance(expr._container, Call)
+        new_call = self.map_call(expr._container)
+        new_result = new_call[expr.name]
+
+        assert isinstance(new_result, NamedCallResult)
+        assert isinstance(new_result._container, Call)
+
+        # do not use _materialize_if_mpms as tagging a NamedArray is illegal.
+        if new_result.tags_of_type(ImplStored):
+            return MPMSMaterializerAccumulator(frozenset([new_result]),
+                                               new_result)
+        else:
+            from functools import reduce
+            # recurse on the original bindings, as map_call does, rather than
+            # on the already-rewritten ones held by new_result.
+            materialized_predecessors: FrozenSet[Array] = (
+                reduce(frozenset.union,
+                       (self.rec(bnd).materialized_predecessors
+                        for bnd in expr._container.bindings.values()),
+                       frozenset())
+            )
+            return MPMSMaterializerAccumulator(materialized_predecessors,
+                                               new_result)
+
 # }}}
 
 
@@ -1487,9 +1706,30 @@ class UsersCollector(CachedMapper[ArrayOrNames]):
     def map_distributed_recv(self, expr: DistributedRecv) -> None:
         self.rec_idx_or_size_tuple(expr, expr.shape)
 
+    def map_function_definition(self, expr: FunctionDefinition, *args: Any
+                                ) -> None:
+        raise AssertionError("Control shouldn't reach at this point."
+                             " Instantiate another UsersCollector to"
+                             " traverse the callee function.")
+
+    def map_call(self, expr: Call, *args: Any) -> None:
+        for bnd in expr.bindings.values():
+            self.rec(bnd)
+
+    def map_named_call_result(self, expr: NamedCallResult, *args: Any) -> None:
+        """
+        Record *expr* as a user of each binding of its call, then visit the
+        call itself.
+        """
+        assert isinstance(expr._container, Call)
+        for bnd in expr._container.bindings.values():
+            self.node_to_users.setdefault(bnd, set()).add(expr)
+
+        self.rec(expr._container)
+
 
 def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames,
-                                             Set[ArrayOrNames]]:
+                                          Set[ArrayOrNames]]:
     """
     Returns a mapping from node in *expr* to its direct users.
     """
diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py
index 6f302fa..b836e38 100644
--- a/pytato/visualization/dot.py
+++ b/pytato/visualization/dot.py
@@ -37,6 +37,7 @@ from pytools import UniqueNameGenerator
 from pytools.tag import Tag
 from pytools.codegen import CodeGenerator as CodeGeneratorBase
 from pytato.loopy import LoopyCall
+from pytato.function import Call, NamedCallResult
 
 from pytato.array import (
         Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase,
@@ -243,6 +244,28 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
 
         self.nodes[expr] = info
 
+    def map_call(self, expr: Call) -> None:
+        for bnd in expr.bindings.values():
+            self.rec(bnd)
+
+        self.nodes[expr] = DotNodeInfo(
+            title=expr.__class__.__name__,
+            edges=dict(expr.bindings),
+            fields={
+                "addr": hex(id(expr)),
+                "tags": stringify_tags(expr.tags),
+            }
+        )
+
+    def map_named_call_result(self, expr: NamedCallResult) -> None:
+        self.rec(expr._container)
+        self.nodes[expr] = DotNodeInfo(
+                title=expr.__class__.__name__,
+                edges={"": expr._container},
+                fields={"addr": hex(id(expr)),
+                        "name": expr.name},
+        )
+
 
 def dot_escape(s: str) -> str:
     # "\" and HTML are significant in graphviz.
@@ -551,4 +574,4 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays,
     from pytools.graphviz import show_dot
     show_dot(dot_code, **kwargs)
 
-# }}}
+# vim:fdm=marker
-- 
GitLab