From a7b372aad8163480ae587adf83b5332ee46804eb Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Wed, 16 Jun 2021 10:33:35 -0500 Subject: [PATCH] make compile independent of 'inputs_like' 'inputs_like' was an unnecessary argument to make the implementation easier. But that led to a confusing interface. Now the compiled operator is smart enough to do the inference of the input data types on its own. It maintains a mapping from the input dtypes it saw to the executable. --- arraycontext/context.py | 7 +- arraycontext/impl/pytato.py | 186 +++++++++++++++++++++++------------- 2 files changed, 124 insertions(+), 69 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index de4aa69..411c531 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -102,13 +102,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Sequence, Union, Callable, Any, Tuple +from typing import Sequence, Union, Callable, Any from abc import ABC, abstractmethod import numpy as np from pytools import memoize_method from pytools.tag import Tag -from numbers import Number # {{{ ArrayContext @@ -351,9 +350,7 @@ class ArrayContext(ABC): "setup-only" array context "leaks" into the application. """ - def compile(self, f: Callable[[Any], Any], - inputs_like: Tuple[Union[Number, np.ndarray], ...]) -> Callable[ - ..., Any]: + def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]: """Compiles *f* for repeated use on this array context. *f* is expected to be a `pure function <https://en.wikipedia.org/wiki/Pure_function>`__ performing an array computation. diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index e8f4b44..612f4a9 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -26,17 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" - from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace from arraycontext.context import ArrayContext from arraycontext.container.traversal import \ rec_multimap_array_container, rec_map_array_container import numpy as np -from typing import Any, Callable, Tuple, Union, Sequence +from typing import Any, Callable, Tuple, Union, Sequence, Mapping from pytools.tag import Tag -from numbers import Number import loopy as lp +from dataclasses import dataclass, field +from pyrsistent import pmap, PMap class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): @@ -167,10 +167,127 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): f"(got {order})") return rec_map_array_container(_rec_ravel, a) + # }}} +class AbstractInputDescriptor: + def __eq__(self, other): + raise NotImplementedError + + def __hash__(self, other): + raise NotImplementedError + + +@dataclass(frozen=True, eq=True) +class ScalarInputDescriptor(AbstractInputDescriptor): + dtype: np.dtype + + +@dataclass(frozen=True, eq=True) +class ArrayContainerInputDescriptor(AbstractInputDescriptor): + id_to_ary_descr: "PMap[Tuple[Union[str, int], ...], Tuple[np.dtype, \ + Tuple[int, ...]]]" + + +@dataclass class PytatoCompiledOperator: + actx: ArrayContext + f: Callable[[Any], Any] + program_cache: Mapping[Tuple[AbstractInputDescriptor], + "PytatoExecutable"] = field(default_factory=lambda: {}) + + def __call__(self, *args): + + from arraycontext import (rec_keyed_map_array_container, + is_array_container) + import pytato as pt + + def to_arg_descr(arg): + if np.isscalar(arg): + return ScalarInputDescriptor(np.dtype(arg)) + elif is_array_container(arg): + id_to_ary_descr = {} + + def id_collector(keys, ary): + id_to_ary_descr[keys] = (np.dtype(ary.dtype), + ary.shape) + return ary + + rec_keyed_map_array_container(id_collector, arg) + return ArrayContainerInputDescriptor(pmap(id_to_ary_descr)) + else: + raise ValueError("Argument to a compiled operator should be" 
+ " either a scalar or an array container. Got" + f" '{arg}'.") + + arg_descrs = tuple(to_arg_descr(arg) for arg in args) + + try: + exec_f = self.program_cache[arg_descrs] + except KeyError: + pass + else: + return exec_f(*args) + + dict_of_named_arrays = {} + # output_naming_map: result id to name of the named array in the + # generated pytato DAG. + output_naming_map = {} + # input_naming_map: argument id to placeholder name in the generated + # pytato DAG. + input_naming_map = {} + + def to_placeholder(arg, pos): + if np.isscalar(arg): + name = f"_actx_in_{pos}" + input_naming_map[(pos, )] = name + return pt.make_placeholder((), np.dtype(arg), name) + elif is_array_container(arg): + def _rec_to_placeholder(keys, ary): + name = f"_actx_in_{pos}_" + "_".join(str(key) + for key in keys) + input_naming_map[(pos,) + keys] = name + return pt.make_placeholder(ary.shape, ary.dtype, + name) + return rec_keyed_map_array_container(_rec_to_placeholder, + arg) + else: + raise NotImplementedError(type(arg)) + + outputs = self.f(*[to_placeholder(arg, iarg) + for iarg, arg in enumerate(args)]) + + if not is_array_container(outputs): + # TODO: We could possibly just short-circuit this interface if the + # returned type is a scalar. Not sure if it's worth it though. 
+ raise ValueError(f"Function {self.f.__name__} to be compiled did not" + " return an array container.") + + def _as_dict_of_named_arrays(keys, ary): + name = "_pt_out_" + "_".join(str(key) + for key in keys) + output_naming_map[keys] = name + dict_of_named_arrays[name] = ary + return ary + + rec_keyed_map_array_container(_as_dict_of_named_arrays, + outputs) + + pytato_program = pt.generate_loopy(dict_of_named_arrays, + options={"return_dict": True}, + cl_device=self.actx.queue.device) + + self.program_cache[arg_descrs] = PytatoExecutable(self.actx, + pytato_program, + input_naming_map, + output_naming_map, + output_template=outputs) + + return self.program_cache[arg_descrs](*args) + + +class PytatoExecutable: def __init__(self, actx, pytato_program, input_id_to_name_in_program, output_id_to_name_in_program, output_template): self.actx = actx @@ -329,67 +446,8 @@ class PytatoArrayContext(ArrayContext): # }}} - def compile(self, f: Callable[[Any], Any], - inputs_like: Tuple[Union[Number, np.ndarray], ...] 
- ) -> Callable[..., Any]: - from arraycontext import (rec_keyed_map_array_container, - is_array_container) - import pytato as pt - - dict_of_named_arrays = {} - output_naming_map = {} - input_naming_map = {} - - def to_placeholder(input_like, pos): - if isinstance(input_like, np.number): - name = f"_actx_in_{pos}" - input_naming_map[(pos, )] = name - return pt.make_placeholder((), input_like.dtype, name) - elif is_array_container(input_like): - def _rec_to_placeholder(keys, ary): - name = f"_actx_in_{pos}_" + "_".join(str(key) - for key in keys) - input_naming_map[(pos,) + keys] = name - return pt.make_placeholder(ary.shape, ary.dtype, - name) - return rec_keyed_map_array_container(_rec_to_placeholder, - input_like) - else: - raise NotImplementedError("Unknown input type " - f"'{type(input_like)}'.") - - outputs = f(*[to_placeholder(el, iel) - for iel, el in enumerate(inputs_like)]) - - if not is_array_container(outputs): - # TODO: We could possibly just short-circuit this interface if the - # returned type is a scalar. Not sure if it's worth it though. 
- raise ValueError("Function to be compiled did not return an array" - " container.") - - def _as_dict_of_named_arrays(keys, ary): - name = "_pt_out_" + "_".join(str(key) - for key in keys) - output_naming_map[keys] = name - dict_of_named_arrays[name] = ary - return ary - - rec_keyed_map_array_container(_as_dict_of_named_arrays, - outputs) - - pytato_program = pt.generate_loopy(dict_of_named_arrays, - options={"return_dict": True}, - cl_device=self.queue.device) - - if False: - # transforming leads to compile-time slow downs (turning off for now) - pytato_program.program = self.transform_loopy_program(pytato_program - .program) - - return PytatoCompiledOperator(self, pytato_program, - input_naming_map, - output_naming_map, - output_template=outputs) + def compile(self, f: Callable[[Any], Any]) -> Callable[..., Any]: + return PytatoCompiledOperator(self, f) def transform_loopy_program(self, prg): from loopy.translation_unit import for_each_kernel -- GitLab