Skip to content
Snippets Groups Projects
Commit c3051047 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

adds docs for Pytato(CompiledOperator|Executable)

parent e8ab5014
No related branches found
No related tags found
No related merge requests found
"""
.. currentmodule:: arraycontext
.. autoclass:: PytatoPyOpenCLArrayContext
.. automodule:: arraycontext.impl.pytato.compile
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
......@@ -84,7 +86,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
def call_loopy(self, program, **kwargs):
import pyopencl.array as cla
from pytato.loopy import call_loopy # type: ignore
from pytato.loopy import call_loopy
entrypoint, = set(program.callables_table)
# thaw frozen arrays
......@@ -125,7 +127,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
# }}}
def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]:
from arraycontext.impl.pytato import PytatoCompiledOperator
from arraycontext.impl.pytato.compile import PytatoCompiledOperator
return PytatoCompiledOperator(self, f)
def transform_loopy_program(self, prg):
......
"""
.. currentmodule:: arraycontext
.. autoclass:: PytatoPyOpenCLArrayContext
.. currentmodule:: arraycontext.impl.pytato.compile
.. autoclass:: PytatoCompiledOperator
.. autoclass:: PytatoExecutable
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
......@@ -26,9 +27,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from arraycontext.context import ArrayContext
from arraycontext.container import ArrayContainer
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
import numpy as np
from typing import Any, Callable, Tuple, Union, Mapping
from typing import Any, Callable, Tuple, Dict
from dataclasses import dataclass, field
from pyrsistent import pmap, PMap
......@@ -51,8 +53,8 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
@dataclass(frozen=True, eq=True)
class ArrayContainerInputDescriptor(AbstractInputDescriptor):
id_to_ary_descr: "PMap[Tuple[Union[str, int], ...], Tuple[np.dtype, \
Tuple[int, ...]]]"
id_to_ary_descr: "PMap[Tuple[Any, ...], Tuple[np.dtype, \
Tuple[int, ...]]]"
def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
......@@ -74,17 +76,35 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
@dataclass
class PytatoCompiledOperator:
actx: ArrayContext
f: Callable[[Any], Any]
program_cache: Mapping[Tuple[AbstractInputDescriptor],
"PytatoExecutable"] = field(default_factory=lambda: {})
"""
Records a side-effect-free callable :attr:`PytatoCompiledOperator.f`, that
would be specialized for different input types
:meth:`PytatoCompiledOperator.__call__` is invoked with.
... attribute f::
The callable that would be specialized into :mod:`pytato` DAGs.
.. automethod:: __call__
"""
def __call__(self, *args):
actx: PytatoPyOpenCLArrayContext
f: Callable[..., Any]
program_cache: Dict[Tuple[AbstractInputDescriptor, ...],
"PytatoExecutable"] = field(default_factory=lambda: {})
def __call__(self, *args: Any) -> Any:
"""
Mimics :attr:`~PytatoCompiledOperator.f` being called with *args*.
Before calling :attr:`~PytatoCompiledOperator.f`, it is compiled to a
:mod:`pytato` DAG that would apply :attr:`~PytatoCompiledOperator.f`
with *args* in a lazy-sense.
"""
from arraycontext.container.traversal import (rec_keyed_map_array_container,
is_array_container)
def to_arg_descr(arg):
def to_arg_descr(arg: Any) -> AbstractInputDescriptor:
if np.isscalar(arg):
return ScalarInputDescriptor(np.dtype(arg))
elif is_array_container(arg):
......@@ -168,16 +188,40 @@ class PytatoCompiledOperator:
return self.program_cache[arg_descrs](*args)
@dataclass
class PytatoExecutable:
def __init__(self, actx, pytato_program, input_id_to_name_in_program,
output_id_to_name_in_program, output_template):
self.actx = actx
self.pytato_program = pytato_program
self.input_id_to_name_in_program = input_id_to_name_in_program
self.output_id_to_name_in_program = output_id_to_name_in_program
self.output_template = output_template
def __call__(self, *args):
"""
A callable which is an instance of :attr:`~PytatoCompiledOperator.f`
specialized for a particular input type fed to it.
.. attribute:: input_id_to_name_in_program
A mapping from input id to the placholder name in
:attr:`PytatoExecutable.pytato_program`. Input id is represented as the
position of :attr:`~PytatoCompiledOperator.f`'s argument augmented with
the leaf array's key if the argument is an array container.
.. attribute:: output_id_to_name_in_program
A mapping from output id to the name of
:class:`pytato.array.NamedArray` in
:attr:`PytatoExecutable.pytato_program`. Output id is represented by
the key of a leaf array in the array container
:attr:`PytatoExecutable.output_template`.
.. attribute:: output_template
An instance of :class:`arraycontext.ArrayContainer` that is the return
type of the callable.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Dict[Tuple[Any, ...], str]
output_id_to_name_in_program: Dict[Tuple[Any, ...], str]
output_template: ArrayContainer
def __call__(self, *args: Any) -> ArrayContainer:
from arraycontext.container import is_array_container
from arraycontext.container.traversal import rec_keyed_map_array_container
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment