From 8dab9bc9a1691f8bf4298c3646b2c6225c5e7239 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Thu, 23 Dec 2021 15:17:02 -0600 Subject: [PATCH] PytatoPyOpenCLArrayContext.compile: support returning arrays `compile` only supported compiling callables that returned array containers. Extends the logic to support compiling callables that simply return thawed arrays. --- arraycontext/impl/pytato/compile.py | 98 +++++++++++++++++++++++------ 1 file changed, 79 insertions(+), 19 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index b98a2ad..71f98a8 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -32,6 +32,7 @@ from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.container.traversal import rec_keyed_map_array_container +import abc import numpy as np from typing import Any, Callable, Tuple, Dict, Mapping from dataclasses import dataclass, field @@ -81,7 +82,7 @@ class ScalarInputDescriptor(AbstractInputDescriptor): @dataclass(frozen=True, eq=True) class LeafArrayDescriptor(AbstractInputDescriptor): dtype: np.dtype - shape: Tuple[int, ...] + shape: pt.array.ShapeType # }}} @@ -140,9 +141,14 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], return ary rec_keyed_map_array_container(id_collector, arg) + elif isinstance(arg, pt.Array): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype), + arg.shape) else: raise ValueError("Argument to a compiled operator should be" - " either a scalar or an array container. Got" + " either a scalar, pt.Array or an array container. Got" f" '{arg}'.") return pmap(arg_id_to_arg), pmap(arg_id_to_descr) @@ -157,6 +163,9 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name): if np.isscalar(arg): name = arg_id_to_name[(kw,)] return pt.make_placeholder(name, (), np.dtype(type(arg))) + elif isinstance(arg, pt.Array): + name = arg_id_to_name[(kw,)] + return pt.make_placeholder(name, arg.shape, arg.dtype) elif is_array_container_type(arg.__class__): def _rec_to_placeholder(keys, ary): name = arg_id_to_name[(kw,) + keys] @@ -218,16 +227,28 @@ class LazilyCompilingFunctionCaller: return pytato_program - def _dag_to_compiled_func(self, dict_of_named_arrays, + def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, input_id_to_name_in_program, output_id_to_name_in_program, output_template): - pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays) - - return CompiledFunction( + if isinstance(ary_or_dict_of_named_arrays, pt.Array): + output_id = "_pt_out" + dict_of_named_arrays = pt.make_dict_of_named_arrays( + {output_id: ary_or_dict_of_named_arrays}) + pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays) + return CompiledFunctionReturningArray( self.actx, pytato_program, input_id_to_name_in_program=input_id_to_name_in_program, - output_id_to_name_in_program=output_id_to_name_in_program, - output_template=output_template) + output_name_in_program=output_id) + elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): + pytato_program = self._dag_to_transformed_loopy_prg( + ary_or_dict_of_named_arrays) + return CompiledFunctionReturningArrayContainer( + self.actx, pytato_program, + input_id_to_name_in_program=input_id_to_name_in_program, + output_id_to_name_in_program=output_id_to_name_in_program, + output_template=output_template) + else: + raise NotImplementedError(type(ary_or_dict_of_named_arrays)) def __call__(self, *args: Any, **kwargs: Any) -> Any: """ @@ -261,13 +282,14 @@ class LazilyCompilingFunctionCaller: **{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program) for kw, arg in kwargs.items()}) - if not is_array_container_type(output_template.__class__): + if (not (is_array_container_type(output_template.__class__) + or isinstance(output_template, pt.Array))): # TODO: We could possibly just short-circuit this interface if the # returned type is a scalar. Not sure if it's worth it though. raise NotImplementedError( f"Function '{self.f.__name__}' to be compiled " - "did not return an array container, but an instance of " - f"'{output_template.__class__}' instead.") + "did not return an array container or pt.Array," + f" but an instance of '{output_template.__class__}' instead.") def _as_dict_of_named_arrays(keys, ary): name = "_pt_out_" + "_".join(str(key) @@ -312,8 +334,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): return input_kwargs_for_loopy -@dataclass(frozen=True) -class CompiledFunction: +class CompiledFunction(abc.ABC): """ A callable which captures the :class:`pytato.target.BoundProgram` resulting from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of @@ -328,6 +349,23 @@ class CompiledFunction: position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented with the leaf array's key if the argument is an array container. + + .. automethod:: __call__ + """ + + @abc.abstractmethod + def __call__(self, arg_id_to_arg) -> Any: + """ + :arg arg_id_to_arg: Mapping from input id to the passed argument. See + :attr:`CompiledFunction.input_id_to_name_in_program` for input id's + representation. + """ + pass + + +@dataclass(frozen=True) +class CompiledFunctionReturningArrayContainer(CompiledFunction): + """ .. attribute:: output_id_to_name_in_program A mapping from output id to the name of @@ -341,7 +379,6 @@ class CompiledFunction: An instance of :class:`arraycontext.ArrayContainer` that is the return type of the callable. """ - actx: PytatoPyOpenCLArrayContext pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] @@ -349,11 +386,6 @@ class CompiledFunction: output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - """ - :arg arg_id_to_arg: Mapping from input id to the passed argument. See - :attr:`CompiledFunction.input_id_to_name_in_program` for input id's - representation. - """ input_kwargs_for_loopy = _args_to_cl_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -371,3 +403,31 @@ class CompiledFunction: return rec_keyed_map_array_container(to_output_template, self.output_template) + + +@dataclass(frozen=True) +class CompiledFunctionReturningArray(CompiledFunction): + """ + .. attribute:: output_name_in_program + + Name of the output array in the program. + """ + actx: PytatoPyOpenCLArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_name: str + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_cl_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + evt, out_dict = self.pytato_program(queue=self.actx.queue, + allocator=self.actx.allocator, + **input_kwargs_for_loopy) + + # FIXME Kernels (for now) allocate tons of memory in temporaries. If we + # race too far ahead with enqueuing, there is a distinct risk of + # running out of memory. This mitigates that risk a bit, for now. + evt.wait() + + return self.actx.thaw(out_dict[self.output_name]) -- GitLab