diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 68a76adba5572486af1e9d15763679b240358782..3cbc895fdf73bc0ec3200f2d84d2f2b5441c971d 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -33,7 +33,7 @@ from arraycontext.container.traversal import (rec_keyed_map_array_container, is_array_container) import numpy as np -from typing import Any, Callable, Tuple, Dict +from typing import Any, Callable, Tuple, Dict, Mapping from dataclasses import dataclass, field from pyrsistent import pmap, PMap @@ -65,17 +65,6 @@ class LeafArrayDescriptor: dtype: np.dtype shape: Tuple[int, ...] - -@dataclass(frozen=True, eq=True) -class ArrayContainerInputDescriptor(AbstractInputDescriptor): - """ - .. attribute id_to_ary_descr:: - - A mapping from keys of leaf arrays of an array container to their - :class:`LeafArrayDescriptor`. - """ - id_to_ary_descr: "PMap[Tuple[Any, ...], LeafArrayDescriptor]" - # }}} @@ -101,28 +90,48 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: return "_".join(_rec_str(key) for key in keys) -def _to_arg_descr(arg: Any) -> AbstractInputDescriptor: +def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] + ) -> "Tuple[PMap[Tuple[Any, ...],\ + Any],\ + PMap[Tuple[Any, ...],\ + AbstractInputDescriptor]\ + ]": """ - Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. - Returns a :class:`AbstractInputDescriptor` for a - attr:`LazilyCompilingFunctionCaller.f`'s input argument. + Helper for :meth:`LazilyCompilingFunctionCaller.__call__`. Extracts the + argument id to argument values and descriptor mappings from the input + arguments. See :attr:`CompiledFunction.input_id_to_name_in_program` for + argument-id's representation. """ - if np.isscalar(arg): - return ScalarInputDescriptor(np.dtype(arg)) - elif is_array_container(arg): - id_to_ary_descr = {} - - def id_collector(keys, ary): - id_to_ary_descr[keys] = LeafArrayDescriptor(np.dtype(ary.dtype), - ary.shape) - return ary + arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} + arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} - rec_keyed_map_array_container(id_collector, arg) - return ArrayContainerInputDescriptor(pmap(id_to_ary_descr)) - else: - raise ValueError("Argument to a compiled operator should be" - " either a scalar or an array container. Got" - f" '{arg}'.") + def to_arg_descr(iarg: int, arg: Any) -> None: + """ + Returns a :class:`AbstractInputDescriptor` for a + attr:`LazilyCompilingFunctionCaller.f`'s input argument. + """ + if np.isscalar(arg): + arg_id = (iarg,) + arg_id_to_arg[arg_id] = arg + arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(arg)) + elif is_array_container(arg): + def id_collector(keys, ary): + arg_id = (iarg,) + keys + arg_id_to_arg[arg_id] = ary + arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(ary.dtype), + ary.shape) + return ary + + rec_keyed_map_array_container(id_collector, arg) + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar or an array container. Got" + f" '{arg}'.") + + for iarg, arg in enumerate(args): + to_arg_descr(iarg, arg) + + return pmap(arg_id_to_arg), pmap(arg_id_to_descr) @dataclass @@ -141,7 +150,7 @@ class LazilyCompilingFunctionCaller: actx: PytatoPyOpenCLArrayContext f: Callable[..., Any] - program_cache: Dict[Tuple[AbstractInputDescriptor, ...], + program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", "CompiledFunction"] = field(default_factory=lambda: {}) def __call__(self, *args: Any) -> Any: @@ -151,15 +160,14 @@ class LazilyCompilingFunctionCaller: :mod:`pytato` DAG that would apply :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. """ - - arg_descrs = tuple(_to_arg_descr(arg) for arg in args) + arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args) try: - compiled_f = self.program_cache[arg_descrs] + compiled_f = self.program_cache[arg_id_to_descr] except KeyError: pass else: - return compiled_f(*args) + return compiled_f(arg_id_to_arg) dict_of_named_arrays = {} # output_naming_map: result id to name of the named array in the @@ -210,13 +218,12 @@ class LazilyCompilingFunctionCaller: options={"return_dict": True}, cl_device=self.actx.queue.device) - self.program_cache[arg_descrs] = CompiledFunction(self.actx, - pytato_program, - input_naming_map, - output_naming_map, - output_template=outputs) + self.program_cache[arg_id_to_descr] = CompiledFunction( + self.actx, pytato_program, + input_naming_map, output_naming_map, + output_template=outputs) - return self.program_cache[arg_descrs](*args) + return self.program_cache[arg_id_to_descr](arg_id_to_arg) @dataclass @@ -250,44 +257,36 @@ class CompiledFunction: actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram - input_id_to_name_in_program: Dict[Tuple[Any, ...], str] - output_id_to_name_in_program: Dict[Tuple[Any, ...], str] + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] output_template: ArrayContainer - def __call__(self, *args: Any) -> ArrayContainer: - from arraycontext.container import is_array_container + def __call__(self, arg_id_to_arg) -> ArrayContainer: + """ + :arg arg_id_to_arg: Mapping from input id to the passed argument. See + :attr:`CompiledFunction.input_id_to_name_in_program` for input id's + representation. + """ from arraycontext.container.traversal import rec_keyed_map_array_container input_kwargs_to_loopy = {} # {{{ extract loopy arguments execute the program - for pos, arg in enumerate(args): + for arg_id, arg in arg_id_to_arg.items(): if np.isscalar(arg): - input_kwargs_to_loopy[self.input_id_to_name_in_program[(pos,)]] = ( - cla.to_device(self.actx.queue, np.array(arg))) - elif is_array_container(arg): - def _extract_lpy_kwargs(keys, ary): - if isinstance(ary, pt.array.DataWrapper): - processed_ary = ary.data - elif isinstance(ary, cla.Array): - processed_ary = ary - elif isinstance(ary, pt.Array): - processed_ary = (self.actx.freeze(ary) - .with_queue(self.actx.queue)) - else: - raise TypeError("Expect pytato.Array or CL-array, got " - f"{type(ary)}") - - input_kwargs_to_loopy[ - self.input_id_to_name_in_program[(pos,) - + keys]] = processed_ary - return ary - - rec_keyed_map_array_container(_extract_lpy_kwargs, arg) + arg = cla.to_device(self.actx.queue, np.array(arg)) + elif isinstance(arg, pt.array.DataWrapper): + arg = arg.data + elif isinstance(arg, cla.Array): + pass + elif isinstance(arg, pt.Array): + arg = self.actx.freeze(arg).with_queue(self.actx.queue) else: raise NotImplementedError(type(arg)) + input_kwargs_to_loopy[self.input_id_to_name_in_program[arg_id]] = arg + # {{{ the generated program might not have depended on some of the # inputs => do not pass those to the loopy kernel