From c3051047729fdb007fffd611bac5edb2762a98b6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sun, 20 Jun 2021 20:24:44 -0500 Subject: [PATCH] adds docs for Pytato(CompiledOperator|Executable) --- arraycontext/impl/pytato/__init__.py | 6 +- arraycontext/impl/pytato/compile.py | 86 +++++++++++++++++++++------- 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 2a4556c..e42da63 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -1,6 +1,8 @@ """ .. currentmodule:: arraycontext .. autoclass:: PytatoPyOpenCLArrayContext + +.. automodule:: arraycontext.impl.pytato.compile """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -84,7 +86,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def call_loopy(self, program, **kwargs): import pyopencl.array as cla - from pytato.loopy import call_loopy # type: ignore + from pytato.loopy import call_loopy entrypoint, = set(program.callables_table) # thaw frozen arrays @@ -125,7 +127,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): # }}} def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]: - from arraycontext.impl.pytato import PytatoCompiledOperator + from arraycontext.impl.pytato.compile import PytatoCompiledOperator return PytatoCompiledOperator(self, f) def transform_loopy_program(self, prg): diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 24f17a2..81895a4 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -1,6 +1,7 @@ """ -.. currentmodule:: arraycontext -.. autoclass:: PytatoPyOpenCLArrayContext +.. currentmodule:: arraycontext.impl.pytato.compile +.. autoclass:: PytatoCompiledOperator +.. autoclass:: PytatoExecutable """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -26,9 +27,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.context import ArrayContext +from arraycontext.container import ArrayContainer +from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext import numpy as np -from typing import Any, Callable, Tuple, Union, Mapping +from typing import Any, Callable, Tuple, Dict from dataclasses import dataclass, field from pyrsistent import pmap, PMap @@ -51,8 +53,8 @@ class ScalarInputDescriptor(AbstractInputDescriptor): @dataclass(frozen=True, eq=True) class ArrayContainerInputDescriptor(AbstractInputDescriptor): - id_to_ary_descr: "PMap[Tuple[Union[str, int], ...], Tuple[np.dtype, \ - Tuple[int, ...]]]" + id_to_ary_descr: "PMap[Tuple[Any, ...], Tuple[np.dtype, \ + Tuple[int, ...]]]" def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: @@ -74,17 +76,35 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: @dataclass class PytatoCompiledOperator: - actx: ArrayContext - f: Callable[[Any], Any] - program_cache: Mapping[Tuple[AbstractInputDescriptor], - "PytatoExecutable"] = field(default_factory=lambda: {}) + """ + Records a side-effect-free callable :attr:`PytatoCompiledOperator.f`, that + would be specialized for different input types + :meth:`PytatoCompiledOperator.__call__` is invoked with. + + ... attribute f:: + + The callable that would be specialized into :mod:`pytato` DAGs. + + .. automethod:: __call__ + """ - def __call__(self, *args): + actx: PytatoPyOpenCLArrayContext + f: Callable[..., Any] + program_cache: Dict[Tuple[AbstractInputDescriptor, ...], + "PytatoExecutable"] = field(default_factory=lambda: {}) + + def __call__(self, *args: Any) -> Any: + """ + Mimics :attr:`~PytatoCompiledOperator.f` being called with *args*. + Before calling :attr:`~PytatoCompiledOperator.f`, it is compiled to a + :mod:`pytato` DAG that would apply :attr:`~PytatoCompiledOperator.f` + with *args* in a lazy-sense. + """ from arraycontext.container.traversal import (rec_keyed_map_array_container, is_array_container) - def to_arg_descr(arg): + def to_arg_descr(arg: Any) -> AbstractInputDescriptor: if np.isscalar(arg): return ScalarInputDescriptor(np.dtype(arg)) elif is_array_container(arg): @@ -168,16 +188,40 @@ class PytatoCompiledOperator: return self.program_cache[arg_descrs](*args) +@dataclass class PytatoExecutable: - def __init__(self, actx, pytato_program, input_id_to_name_in_program, - output_id_to_name_in_program, output_template): - self.actx = actx - self.pytato_program = pytato_program - self.input_id_to_name_in_program = input_id_to_name_in_program - self.output_id_to_name_in_program = output_id_to_name_in_program - self.output_template = output_template - - def __call__(self, *args): + """ + A callable which is an instance of :attr:`~PytatoCompiledOperator.f` + specialized for a particular input type fed to it. + + .. attribute:: input_id_to_name_in_program + + A mapping from input id to the placholder name in + :attr:`PytatoExecutable.pytato_program`. Input id is represented as the + position of :attr:`~PytatoCompiledOperator.f`'s argument augmented with + the leaf array's key if the argument is an array container. + + .. attribute:: output_id_to_name_in_program + + A mapping from output id to the name of + :class:`pytato.array.NamedArray` in + :attr:`PytatoExecutable.pytato_program`. Output id is represented by + the key of a leaf array in the array container + :attr:`PytatoExecutable.output_template`. + + .. attribute:: output_template + + An instance of :class:`arraycontext.ArrayContainer` that is the return + type of the callable. + """ + + actx: PytatoPyOpenCLArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Dict[Tuple[Any, ...], str] + output_id_to_name_in_program: Dict[Tuple[Any, ...], str] + output_template: ArrayContainer + + def __call__(self, *args: Any) -> ArrayContainer: from arraycontext.container import is_array_container from arraycontext.container.traversal import rec_keyed_map_array_container -- GitLab