From eef28550964acb3a3bfd3e76342e1149674a16c5 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Thu, 31 Mar 2022 23:38:45 -0500
Subject: [PATCH] transform DAG before making it a placeholder

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
---
 arraycontext/impl/pytato/compile.py | 53 ++++++++++++++++++++++++++---
 1 file changed, 48 insertions(+), 5 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index d83e376..8b20a7c 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -28,7 +28,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext.container import ArrayContainer, is_array_container_type
+from arraycontext.container import (ArrayContainer, is_array_container_type,
+                                    ArrayT)
 from arraycontext import PytatoPyOpenCLArrayContext
 from arraycontext.container.traversal import rec_keyed_map_array_container
 
@@ -154,7 +155,40 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
     return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
 
 
-def _get_f_placeholder_args(arg, kw, arg_id_to_name):
+def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
+    """
+    Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder`
+    in :meth:`LazilyCompilingFunctionCaller.__call__`.
+
+    Preprocessing here refers to:
+
+    - Metadata Inference that is supplied via *actx*\'s
+      :meth:`PytatoPyOpenCLArrayContext.transform_dag`.
+    """
+    import pyopencl.array as cla
+    from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array,
+                                                              TaggableCLArray)
+    if isinstance(ary, pt.Array):
+        dag = pt.make_dict_of_named_arrays({"_actx_out": ary})
+        # Transform the DAG to give metadata inference a chance to do its job
+        return actx.transform_dag(dag)["_actx_out"].expr
+    elif isinstance(ary, TaggableCLArray):
+        return ary
+    elif isinstance(ary, cla.Array):
+        from warnings import warn
+        warn("Passing pyopencl.array.Array to a compiled callable"
+             " is deprecated and will stop working in 2023."
+             " Use `to_tagged_cl_array` to convert the array to"
+             " TaggableCLArray", DeprecationWarning, stacklevel=2)
+
+        return to_tagged_cl_array(ary,
+                                  axes=None,
+                                  tags=frozenset())
+    else:
+        raise NotImplementedError(type(ary))
+
+
+def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
     """
     Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the
     placeholder version of an argument to
@@ -165,10 +199,16 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
         return pt.make_placeholder(name, (), np.dtype(type(arg)))
     elif isinstance(arg, pt.Array):
         name = arg_id_to_name[(kw,)]
-        return pt.make_placeholder(name, arg.shape, arg.dtype)
+        # Transform the DAG to give metadata inference a chance to do its job
+        arg = _to_input_for_compiled(arg, actx)
+        return pt.make_placeholder(name, arg.shape, arg.dtype,
+                                   axes=arg.axes,
+                                   tags=arg.tags)
     elif is_array_container_type(arg.__class__):
         def _rec_to_placeholder(keys, ary):
             name = arg_id_to_name[(kw,) + keys]
+            # Transform the DAG to give metadata inference a chance to do its job
+            ary = _to_input_for_compiled(ary, actx)
             return pt.make_placeholder(name,
                                        ary.shape,
                                        ary.dtype,
@@ -293,9 +333,12 @@ class LazilyCompilingFunctionCaller:
             for arg_id in arg_id_to_arg}
 
         output_template = self.f(
-                *[_get_f_placeholder_args(arg, iarg, input_id_to_name_in_program)
+                *[_get_f_placeholder_args(arg, iarg,
+                                          input_id_to_name_in_program, self.actx)
                     for iarg, arg in enumerate(args)],
-                **{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program)
+                **{kw: _get_f_placeholder_args(arg, kw,
+                                               input_id_to_name_in_program,
+                                               self.actx)
                     for kw, arg in kwargs.items()})
 
         if (not (is_array_container_type(output_template.__class__)
-- 
GitLab