From c3051047729fdb007fffd611bac5edb2762a98b6 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sun, 20 Jun 2021 20:24:44 -0500
Subject: [PATCH] adds docs for Pytato(CompiledOperator|Executable)

---
 arraycontext/impl/pytato/__init__.py |  6 +-
 arraycontext/impl/pytato/compile.py  | 86 +++++++++++++++++++++-------
 2 files changed, 69 insertions(+), 23 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 2a4556c..e42da63 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -1,6 +1,8 @@
 """
 .. currentmodule:: arraycontext
 .. autoclass:: PytatoPyOpenCLArrayContext
+
+.. automodule:: arraycontext.impl.pytato.compile
 """
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -84,7 +86,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
 
     def call_loopy(self, program, **kwargs):
         import pyopencl.array as cla
-        from pytato.loopy import call_loopy  # type: ignore
+        from pytato.loopy import call_loopy
         entrypoint, = set(program.callables_table)
 
         # thaw frozen arrays
@@ -125,7 +127,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     # }}}
 
     def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]:
-        from arraycontext.impl.pytato import PytatoCompiledOperator
+        from arraycontext.impl.pytato.compile import PytatoCompiledOperator
         return PytatoCompiledOperator(self, f)
 
     def transform_loopy_program(self, prg):
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 24f17a2..81895a4 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -1,6 +1,7 @@
 """
-.. currentmodule:: arraycontext
-.. autoclass:: PytatoPyOpenCLArrayContext
+.. currentmodule:: arraycontext.impl.pytato.compile
+.. autoclass:: PytatoCompiledOperator
+.. autoclass:: PytatoExecutable
 """
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -26,9 +27,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext.context import ArrayContext
+from arraycontext.container import ArrayContainer
+from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
 import numpy as np
-from typing import Any, Callable, Tuple, Union, Mapping
+from typing import Any, Callable, Tuple, Dict
 from dataclasses import dataclass, field
 from pyrsistent import pmap, PMap
 
@@ -51,8 +53,8 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
 
 @dataclass(frozen=True, eq=True)
 class ArrayContainerInputDescriptor(AbstractInputDescriptor):
-    id_to_ary_descr: "PMap[Tuple[Union[str, int], ...], Tuple[np.dtype, \
-                                                        Tuple[int, ...]]]"
+    id_to_ary_descr: "PMap[Tuple[Any, ...], Tuple[np.dtype, \
+                                                  Tuple[int, ...]]]"
 
 
 def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
@@ -74,17 +76,35 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
 
 @dataclass
 class PytatoCompiledOperator:
-    actx: ArrayContext
-    f: Callable[[Any], Any]
-    program_cache: Mapping[Tuple[AbstractInputDescriptor],
-                           "PytatoExecutable"] = field(default_factory=lambda: {})
+    """
+    Records a side-effect-free callable :attr:`PytatoCompiledOperator.f`, that
+    would be specialized for different input types
+    :meth:`PytatoCompiledOperator.__call__` is invoked with.
+
+    ... attribute f::
+
+        The callable that would be specialized into :mod:`pytato` DAGs.
+
+    .. automethod:: __call__
+    """
 
-    def __call__(self, *args):
+    actx: PytatoPyOpenCLArrayContext
+    f: Callable[..., Any]
+    program_cache: Dict[Tuple[AbstractInputDescriptor, ...],
+                        "PytatoExecutable"] = field(default_factory=lambda: {})
+
+    def __call__(self, *args: Any) -> Any:
+        """
+        Mimics :attr:`~PytatoCompiledOperator.f` being called with *args*.
+        Before calling :attr:`~PytatoCompiledOperator.f`, it is compiled to a
+        :mod:`pytato` DAG that would apply :attr:`~PytatoCompiledOperator.f`
+        with *args* in a lazy-sense.
+        """
 
         from arraycontext.container.traversal import (rec_keyed_map_array_container,
                                                       is_array_container)
 
-        def to_arg_descr(arg):
+        def to_arg_descr(arg: Any) -> AbstractInputDescriptor:
             if np.isscalar(arg):
                 return ScalarInputDescriptor(np.dtype(arg))
             elif is_array_container(arg):
@@ -168,16 +188,40 @@ class PytatoCompiledOperator:
         return self.program_cache[arg_descrs](*args)
 
 
+@dataclass
 class PytatoExecutable:
-    def __init__(self, actx, pytato_program, input_id_to_name_in_program,
-                 output_id_to_name_in_program, output_template):
-        self.actx = actx
-        self.pytato_program = pytato_program
-        self.input_id_to_name_in_program = input_id_to_name_in_program
-        self.output_id_to_name_in_program = output_id_to_name_in_program
-        self.output_template = output_template
-
-    def __call__(self, *args):
+    """
+    A callable which is an instance of :attr:`~PytatoCompiledOperator.f`
+    specialized for a particular input type fed to it.
+
+    .. attribute:: input_id_to_name_in_program
+
+        A mapping from input id to the placholder name in
+        :attr:`PytatoExecutable.pytato_program`. Input id is represented as the
+        position of :attr:`~PytatoCompiledOperator.f`'s argument augmented with
+        the leaf array's key if the argument is an array container.
+
+    .. attribute:: output_id_to_name_in_program
+
+        A mapping from output id to the name of
+        :class:`pytato.array.NamedArray` in
+        :attr:`PytatoExecutable.pytato_program`. Output id is represented by
+        the key of a leaf array in the array container
+        :attr:`PytatoExecutable.output_template`.
+
+    .. attribute:: output_template
+
+       An instance of :class:`arraycontext.ArrayContainer` that is the return
+       type of the callable.
+    """
+
+    actx: PytatoPyOpenCLArrayContext
+    pytato_program: pt.target.BoundProgram
+    input_id_to_name_in_program: Dict[Tuple[Any, ...], str]
+    output_id_to_name_in_program: Dict[Tuple[Any, ...], str]
+    output_template: ArrayContainer
+
+    def __call__(self, *args: Any) -> ArrayContainer:
         from arraycontext.container import is_array_container
         from arraycontext.container.traversal import rec_keyed_map_array_container
 
-- 
GitLab