From b1f8dba26ecdd64162a1a08c59e77d222077360d Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Fri, 4 Nov 2022 23:44:17 -0500 Subject: [PATCH] Support mapper methods for FunctionDefintion, Call Co-authored-by: Andreas Kloeckner <andreask@illinois.edu> --- pytato/codegen.py | 11 +- pytato/equality.py | 27 ++++ pytato/tags.py | 28 +++- pytato/transform/__init__.py | 240 ++++++++++++++++++++++++++++++++++- pytato/visualization/dot.py | 25 +++- 5 files changed, 322 insertions(+), 9 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index 63067fb..d95f96a 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -23,7 +23,7 @@ THE SOFTWARE. """ import dataclasses -from typing import Union, Dict, Tuple, List, Any +from typing import Union, Dict, Tuple, List, Any, Optional from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder, DataInterface, SizeParam, InputArgumentBase, @@ -102,12 +102,17 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] ====================================== ===================================== """ - def __init__(self, target: Target) -> None: + def __init__(self, target: Target, + kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None + ) -> None: super().__init__() self.bound_arguments: Dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target - self.kernels_seen: Dict[str, lp.LoopKernel] = {} + self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {} + + def clone_for_callee(self) -> CodeGenPreprocessor: + return CodeGenPreprocessor(self.target, self.kernels_seen) def map_size_param(self, expr: SizeParam) -> Array: name = expr.name diff --git a/pytato/equality.py b/pytato/equality.py index 5eb1814..42c2978 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -31,6 +31,8 @@ from pytato.array import (AdvancedIndexInContiguousAxes, IndexBase, IndexLambda, NamedArray, Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) +from pytato.function import Call, NamedCallResult, FunctionDefinition +from pytools import memoize_method if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -273,6 +275,31 @@ class EqualityComparer: and expr1.tags == expr2.tags ) + @memoize_method + def map_function_definition(self, expr1: FunctionDefinition, expr2: Any + ) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.parameters == expr2.parameters + and (set(expr1.returns.keys()) == set(expr2.returns.keys())) + and all(self.rec(expr1.returns[k], expr2.returns[k]) + for k in expr1.returns) + and expr1.tags == expr2.tags + ) + + def map_call(self, expr1: Call, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.map_function_definition(expr1.function, expr2.function) + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + for name, bnd in expr1.bindings.items()) + ) + + def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == expr2.name + and self.rec(expr1._container, expr2._container)) + # }}} # vim: fdm=marker diff --git a/pytato/tags.py b/pytato/tags.py index ad6d33d..1e1180e 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -11,6 +11,8 @@ Pre-Defined Tags .. autoclass:: AssumeNonNegative .. autoclass:: ExpandedDimsReshape .. autoclass:: FunctionIdentifier +.. autoclass:: CallImplementationTag +.. autoclass:: InlineCallTag """ from typing import Tuple, Hashable @@ -131,8 +133,28 @@ class ExpandedDimsReshape(UniqueTag): @dataclass(frozen=True) class FunctionIdentifier(UniqueTag): """ - A tag that can be attached to a - :class:`~pytato.function.FunctionDefinition` node to - to describe the function's identifier. + A tag that can be attached to a :class:`~pytato.function.FunctionDefinition` + node to to describe the function's identifier. One can use this to refer + all instances of :class:`~pytato.function.FunctionDefinition`, for example in + transformations.transform.calls.concatenate_calls`. + + .. attribute:: identifier """ identifier: Hashable + + +@dataclass(frozen=True) +class CallImplementationTag(UniqueTag): + """ + A tag that can be attached to a :class:`~pytato.function.Call` node to + direct a :class:`~pytato.target.Target` how the call site should be + lowered. + """ + + +@dataclass(frozen=True) +class InlineCallTag(CallImplementationTag): + r""" + A :class:`CallImplementationTag` that directs the + :class:`pytato.target.Target` to inline the call site. + """ diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index a75083f..1ceb4c4 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pytools import memoize_method + __copyright__ = """ Copyright (C) 2020 Matt Wala Copyright (C) 2020-21 Kaushik Kulkarni @@ -28,6 +30,7 @@ THE SOFTWARE. import logging import numpy as np +from immutables import Map from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, Hashable) @@ -43,10 +46,12 @@ from pytato.array import ( from pytato.distributed.nodes import ( DistributedSendRefHolder, DistributedRecv, DistributedSend) from pytato.loopy import LoopyCall, LoopyCallResult +from pytato.function import Call, NamedCallResult, FunctionDefinition from dataclasses import dataclass from pytato.tags import ImplStored from pymbolic.mapper.optimize import optimize_mapper + ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) @@ -56,6 +61,7 @@ CopyMapperResultT = TypeVar("CopyMapperResultT", # used in CopyMapper CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = FrozenSet[Array] +_SelfMapper = TypeVar("_SelfMapper", bound="Mapper") __doc__ = """ .. currentmodule:: pytato.transform @@ -100,6 +106,11 @@ Internal stuff that is only here because the documentation tool wants it .. class:: CombineT A type variable representing the type of a :class:`CombineMapper`. + +.. class:: _SelfMapper + + A type variable used to represent the type of a mapper in + :meth:`CopyMapper.clone_for_callee`. """ transform_logger = logging.getLogger(__file__) @@ -213,6 +224,8 @@ class CopyMapper(CachedMapper[ArrayOrNames]): The typical use of this mapper is to override individual ``map_`` methods in subclasses to permit term rewriting on an expression graph. + .. automethod:: clone_for_callee + .. note:: This does not copy the data of a :class:`pytato.array.DataWrapper`. @@ -229,6 +242,13 @@ class CopyMapper(CachedMapper[ArrayOrNames]): expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + return type(self)() + def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] ) -> Tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars @@ -372,6 +392,32 @@ class CopyMapper(CachedMapper[ArrayOrNames]): shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + @memoize_method + def map_function_definition(self, + expr: FunctionDefinition) -> FunctionDefinition: + # spawn a new mapper to avoid unsound cache hits, since the namespace of the + # function's body is different from that of the caller. + new_mapper = self.clone_for_callee() + new_returns = {name: new_mapper(ret) + for name, ret in expr.returns.items()} + return FunctionDefinition(expr.parameters, + expr.return_type, + Map(new_returns), + tags=expr.tags + ) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + return Call(self.map_function_definition(expr.function), + Map({name: self.rec(bnd) + for name, bnd in expr.bindings.items()}), + tags=expr.tags, + ) + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + call = self.rec(expr._container) + assert isinstance(call, Call) + return NamedCallResult(call, expr.name) + class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): """ @@ -575,6 +621,26 @@ class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> FunctionDefinition: + raise NotImplementedError("Function definitions are purposefully left" + " unimplemented as the default arguments to a new" + " DAG traversal are tricky to guess.") + + def map_call(self, expr: Call, + *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: + return Call(self.map_function_definition(expr.function, *args, **kwargs), + Map({name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()}), + tags=expr.tags, + ) + + def map_named_call_result(self, expr: NamedCallResult, + *args: Any, **kwargs: Any) -> Array: + call = self.rec(expr._container, *args, **kwargs) + assert isinstance(call, Call) + return NamedCallResult(call, expr.name) + # }}} @@ -685,6 +751,18 @@ class CombineMapper(Mapper, Generic[CombineT]): def map_distributed_recv(self, expr: DistributedRecv) -> CombineT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_function_definition(self, expr: FunctionDefinition) -> CombineT: + raise NotImplementedError("Combining results from a callee expression" + " is context-dependent. Derived classes" + " must override map_function_definition.") + + def map_call(self, expr: Call) -> CombineT: + return self.combine(self.map_function_definition(expr.function), + *[self.rec(bnd) for bnd in expr.bindings.values()]) + + def map_named_call_result(self, expr: NamedCallResult) -> CombineT: + return self.rec(expr._container) + # }}} @@ -752,6 +830,18 @@ class DependencyMapper(CombineMapper[R]): def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) + def map_function_definition(self, expr: FunctionDefinition) -> R: + # do not include arrays from the function's body as it would involve + # putting arrays from different namespaces into the same collection. + return frozenset() + + def map_call(self, expr: Call) -> R: + return self.combine(self.map_function_definition(expr.function), + *[self.rec(bnd) for bnd in expr.bindings.values()]) + + def map_named_call_result(self, expr: NamedCallResult) -> R: + return self.rec(expr._container) + # }}} @@ -796,6 +886,27 @@ class InputGatherer(CombineMapper[FrozenSet[InputArgumentBase]]): def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]: return frozenset([expr]) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FrozenSet[InputArgumentBase]: + # get rid of placeholders local to the function. + new_mapper = InputGatherer() + all_callee_inputs = new_mapper.combine(*[new_mapper(ret) + for ret in expr.returns.values()]) + result: Set[InputArgumentBase] = set() + for inp in all_callee_inputs: + if isinstance(inp, Placeholder): + if inp.name in expr.parameters: + # drop, reference to argument + pass + else: + raise ValueError("function definition refers to non-argument " + f"placeholder named '{inp.name}'") + else: + result.add(inp) + + return frozenset(result) + # }}} @@ -814,6 +925,12 @@ class SizeParamGatherer(CombineMapper[FrozenSet[SizeParam]]): def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]: return frozenset([expr]) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FrozenSet[SizeParam]: + return self.combine(*[self.rec(ret) + for ret in expr.returns.values()]) + # }}} @@ -830,6 +947,9 @@ class WalkMapper(Mapper): .. automethod:: post_visit """ + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + return type(self)() + def visit(self, expr: Any, *args: Any, **kwargs: Any) -> bool: """ If this method returns *True*, *expr* is traversed during the walk. @@ -982,6 +1102,36 @@ class WalkMapper(Mapper): self.post_visit(expr, *args, **kwargs) + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee() + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + self.map_function_definition(expr.function) + for bnd in expr.bindings.values(): + self.rec(bnd) + + self.post_visit(expr) + + def map_named_call_result(self, expr: NamedCallResult, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr._container, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + # }}} @@ -1019,6 +1169,11 @@ class TopoSortMapper(CachedWalkMapper): """A mapper that creates a list of nodes in topological order. :members: topological_order + + .. note:: + + Does not consider the nodes inside a + :class:`~pytato.function.FunctionDefinition`. """ def __init__(self) -> None: @@ -1033,6 +1188,12 @@ class TopoSortMapper(CachedWalkMapper): def post_visit(self, expr: Any) -> None: # type: ignore[override] self.topological_order.append(expr) + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def map_function_definition(self, # type: ignore[override] + expr: FunctionDefinition) -> None: + # do nothing as it includes arrays from a different namespace. + return + # }}} @@ -1048,6 +1209,11 @@ class CachedMapAndCopyMapper(CopyMapper): super().__init__() self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ and does not have map_fn + return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] + # type-ignore-reason:incompatible with Mapper.rec() def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] if expr in self._cache: @@ -1102,7 +1268,14 @@ def _materialize_if_mpms(expr: Array, class MPMSMaterializer(Mapper): - """See :func:`materialize_with_mpms` for an explanation.""" + """ + See :func:`materialize_with_mpms` for an explanation. + + .. attribute:: nsuccessors + + A mapping from a node in the expression graph (i.e. an + :class:`~pytato.Array`) to its number of successors. + """ def __init__(self, nsuccessors: Mapping[Array, int]): super().__init__() self.nsuccessors = nsuccessors @@ -1252,6 +1425,52 @@ class MPMSMaterializer(Mapper): ) -> MPMSMaterializerAccumulator: return MPMSMaterializerAccumulator(frozenset([expr]), expr) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FunctionDefinition: + # spawn a new traversal here. + from pytato.analysis import get_nusers + + returns_dict_of_named_arys = DictOfNamedArrays(expr.returns) + func_nsuccessors = get_nusers(returns_dict_of_named_arys) + new_mapper = MPMSMaterializer(func_nsuccessors) + new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} + return FunctionDefinition(expr.parameters, + expr.return_type, + Map(new_returns), + tags=expr.tags) + + @memoize_method + def map_call(self, expr: Call) -> Call: + return Call(self.map_function_definition(expr.function), + Map({name: self.rec(bnd).expr + for name, bnd in expr.bindings.items()}), + tags=expr.tags) + + def map_named_call_result(self, expr: NamedCallResult + ) -> MPMSMaterializerAccumulator: + assert isinstance(expr._container, Call) + new_call = self.map_call(expr._container) + new_result = new_call[expr.name] + + assert isinstance(new_result, NamedCallResult) + assert isinstance(new_result._container, Call) + + # do not use _materialize_if_mpms as tagging a NamedArray is illegal. + if new_result.tags_of_type(ImplStored): + return MPMSMaterializerAccumulator(frozenset([new_result]), + new_result) + else: + from functools import reduce + materialized_predecessors: FrozenSet[Array] = ( + reduce(frozenset.union, + (self.rec(bnd).materialized_predecessors + for bnd in new_result._container.bindings.values()), + frozenset()) + ) + return MPMSMaterializerAccumulator(materialized_predecessors, + new_result) + # }}} @@ -1487,9 +1706,26 @@ class UsersCollector(CachedMapper[ArrayOrNames]): def map_distributed_recv(self, expr: DistributedRecv) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) + def map_function_definition(self, expr: FunctionDefinition, *args: Any + ) -> None: + raise AssertionError("Control shouldn't reach at this point." + " Instantiate another UsersCollector to" + " traverse the callee function.") + + def map_call(self, expr: Call, *args: Any) -> None: + for bnd in expr.bindings.values(): + self.rec(bnd) + + def map_named_call(self, expr: NamedCallResult, *args: Any) -> None: + assert isinstance(expr._container, Call) + for bnd in expr._container.bindings.values(): + self.node_to_users.setdefault(bnd, set()).add(expr) + + self.rec(expr._container) + def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, - Set[ArrayOrNames]]: + Set[ArrayOrNames]]: """ Returns a mapping from node in *expr* to its direct users. """ diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 6f302fa..b836e38 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -37,6 +37,7 @@ from pytools import UniqueNameGenerator from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall +from pytato.function import Call, NamedCallResult from pytato.array import ( Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase, @@ -243,6 +244,28 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): self.nodes[expr] = info + def map_call(self, expr: Call) -> None: + for bnd in expr.bindings.values(): + self.rec(bnd) + + self.nodes[expr] = DotNodeInfo( + title=expr.__class__.__name__, + edges=dict(expr.bindings), + fields={ + "addr": hex(id(expr)), + "tags": stringify_tags(expr.tags), + } + ) + + def map_named_call_result(self, expr: NamedCallResult) -> None: + self.rec(expr._container) + self.nodes[expr] = DotNodeInfo( + title=expr.__class__.__name__, + edges={"": expr._container}, + fields={"addr": hex(id(expr)), + "name": expr.name}, + ) + def dot_escape(s: str) -> str: # "\" and HTML are significant in graphviz. @@ -551,4 +574,4 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, from pytools.graphviz import show_dot show_dot(dot_code, **kwargs) -# }}} +# vim:fdm=marker -- GitLab