From 8dab9bc9a1691f8bf4298c3646b2c6225c5e7239 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Thu, 23 Dec 2021 15:17:02 -0600
Subject: [PATCH] PytatoPyOpenCLArrayContext.compile: support returning arrays

`compile` only supported compiling callables that returned array
containers. Extends the logic to support compiling callables that simply
return thawed arrays.
---
 arraycontext/impl/pytato/compile.py | 98 +++++++++++++++++++++++------
 1 file changed, 79 insertions(+), 19 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index b98a2ad..71f98a8 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -32,6 +32,7 @@ from arraycontext.container import ArrayContainer, is_array_container_type
 from arraycontext import PytatoPyOpenCLArrayContext
 from arraycontext.container.traversal import rec_keyed_map_array_container
 
+import abc
 import numpy as np
 from typing import Any, Callable, Tuple, Dict, Mapping
 from dataclasses import dataclass, field
@@ -81,7 +82,7 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
 @dataclass(frozen=True, eq=True)
 class LeafArrayDescriptor(AbstractInputDescriptor):
     dtype: np.dtype
-    shape: Tuple[int, ...]
+    shape: pt.array.ShapeType
 
 # }}}
 
@@ -140,9 +141,14 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
                 return ary
 
             rec_keyed_map_array_container(id_collector, arg)
+        elif isinstance(arg, pt.Array):
+            arg_id = (kw,)
+            arg_id_to_arg[arg_id] = arg
+            arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype),
+                                                          arg.shape)
         else:
             raise ValueError("Argument to a compiled operator should be"
-                             " either a scalar or an array container. Got"
+                             " either a scalar, pt.Array or an array container. Got"
                              f" '{arg}'.")
 
     return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
@@ -157,6 +163,9 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
     if np.isscalar(arg):
         name = arg_id_to_name[(kw,)]
         return pt.make_placeholder(name, (), np.dtype(type(arg)))
+    elif isinstance(arg, pt.Array):
+        name = arg_id_to_name[(kw,)]
+        return pt.make_placeholder(name, arg.shape, arg.dtype)
     elif is_array_container_type(arg.__class__):
         def _rec_to_placeholder(keys, ary):
             name = arg_id_to_name[(kw,) + keys]
@@ -218,16 +227,28 @@ class LazilyCompilingFunctionCaller:
 
         return pytato_program
 
-    def _dag_to_compiled_func(self, dict_of_named_arrays,
+    def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
             input_id_to_name_in_program, output_id_to_name_in_program,
             output_template):
-        pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays)
-
-        return CompiledFunction(
+        if isinstance(ary_or_dict_of_named_arrays, pt.Array):
+            output_id = "_pt_out"
+            dict_of_named_arrays = pt.make_dict_of_named_arrays(
+                {output_id: ary_or_dict_of_named_arrays})
+            pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays)
+            return CompiledFunctionReturningArray(
                 self.actx, pytato_program,
                 input_id_to_name_in_program=input_id_to_name_in_program,
-                output_id_to_name_in_program=output_id_to_name_in_program,
-                output_template=output_template)
+                output_name_in_program=output_id)
+        elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays):
+            pytato_program = self._dag_to_transformed_loopy_prg(
+                ary_or_dict_of_named_arrays)
+            return CompiledFunctionReturningArrayContainer(
+                    self.actx, pytato_program,
+                    input_id_to_name_in_program=input_id_to_name_in_program,
+                    output_id_to_name_in_program=output_id_to_name_in_program,
+                    output_template=output_template)
+        else:
+            raise NotImplementedError(type(ary_or_dict_of_named_arrays))
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         """
@@ -261,13 +282,14 @@ class LazilyCompilingFunctionCaller:
                 **{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program)
                     for kw, arg in kwargs.items()})
 
-        if not is_array_container_type(output_template.__class__):
+        if (not (is_array_container_type(output_template.__class__)
+                 or isinstance(output_template, pt.Array))):
             # TODO: We could possibly just short-circuit this interface if the
             # returned type is a scalar. Not sure if it's worth it though.
             raise NotImplementedError(
                 f"Function '{self.f.__name__}' to be compiled "
-                "did not return an array container, but an instance of "
-                f"'{output_template.__class__}' instead.")
+                "did not return an array container or pt.Array,"
+                f" but an instance of '{output_template.__class__}' instead.")
 
         def _as_dict_of_named_arrays(keys, ary):
             name = "_pt_out_" + "_".join(str(key)
@@ -312,8 +334,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
     return input_kwargs_for_loopy
 
 
-@dataclass(frozen=True)
-class CompiledFunction:
+class CompiledFunction(abc.ABC):
     """
     A callable which captures the :class:`pytato.target.BoundProgram`  resulting
     from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of
@@ -328,6 +349,23 @@ class CompiledFunction:
         position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented
         with the leaf array's key if the argument is an array container.
 
+
+    .. automethod:: __call__
+    """
+
+    @abc.abstractmethod
+    def __call__(self, arg_id_to_arg) -> Any:
+        """
+        :arg arg_id_to_arg: Mapping from input id to the passed argument. See
+            :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
+            representation.
+        """
+        pass
+
+
+@dataclass(frozen=True)
+class CompiledFunctionReturningArrayContainer(CompiledFunction):
+    """
     .. attribute:: output_id_to_name_in_program
 
         A mapping from output id to the name of
@@ -341,7 +379,6 @@ class CompiledFunction:
        An instance of :class:`arraycontext.ArrayContainer` that is the return
        type of the callable.
     """
-
     actx: PytatoPyOpenCLArrayContext
     pytato_program: pt.target.BoundProgram
     input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
@@ -349,11 +386,6 @@ class CompiledFunction:
     output_template: ArrayContainer
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
-        """
-        :arg arg_id_to_arg: Mapping from input id to the passed argument. See
-            :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
-            representation.
-        """
         input_kwargs_for_loopy = _args_to_cl_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
 
@@ -371,3 +403,31 @@ class CompiledFunction:
 
         return rec_keyed_map_array_container(to_output_template,
                                              self.output_template)
+
+
+@dataclass(frozen=True)
+class CompiledFunctionReturningArray(CompiledFunction):
+    """
+    .. attribute:: output_name_in_program
+
+        Name of the output array in the program.
+    """
+    actx: PytatoPyOpenCLArrayContext
+    pytato_program: pt.target.BoundProgram
+    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    output_name: str
+
+    def __call__(self, arg_id_to_arg) -> ArrayContainer:
+        input_kwargs_for_loopy = _args_to_cl_buffers(
+                self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
+
+        evt, out_dict = self.pytato_program(queue=self.actx.queue,
+                                            allocator=self.actx.allocator,
+                                            **input_kwargs_for_loopy)
+
+        # FIXME Kernels (for now) allocate tons of memory in temporaries. If we
+        # race too far ahead with enqueuing, there is a distinct risk of
+        # running out of memory. This mitigates that risk a bit, for now.
+        evt.wait()
+
+        return self.actx.thaw(out_dict[self.output_name])
-- 
GitLab