From eef28550964acb3a3bfd3e76342e1149674a16c5 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Thu, 31 Mar 2022 23:38:45 -0500 Subject: [PATCH] transform DAG before making it a placeholder Co-authored-by: Andreas Kloeckner <inform@tiker.net> --- arraycontext/impl/pytato/compile.py | 53 ++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index d83e376..8b20a7c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -28,7 +28,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.container import ArrayContainer, is_array_container_type +from arraycontext.container import (ArrayContainer, is_array_container_type, + ArrayT) from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.container.traversal import rec_keyed_map_array_container @@ -154,7 +155,40 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], return pmap(arg_id_to_arg), pmap(arg_id_to_descr) -def _get_f_placeholder_args(arg, kw, arg_id_to_name): +def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): + """ + Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder` + in :meth:`LazilyCompilingFunctionCaller.__call__`. + + Preprocessing here refers to: + + - Metadata Inference that is supplied via *actx*\'s + :meth:`PytatoPyOpenCLArrayContext.transform_dag`. + """ + import pyopencl.array as cla + from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, + TaggableCLArray) + if isinstance(ary, pt.Array): + dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) + # Transform the DAG to give metadata inference a chance to do its job + return actx.transform_dag(dag)["_actx_out"].expr + elif isinstance(ary, TaggableCLArray): + return ary + elif isinstance(ary, cla.Array): + from warnings import warn + warn("Passing pyopencl.array.Array to a compiled callable" + " is deprecated and will stop working in 2023." + " Use `to_tagged_cl_array` to convert the array to" + " TaggableCLArray", DeprecationWarning, stacklevel=2) + + return to_tagged_cl_array(ary, + axes=None, + tags=frozenset()) + else: + raise NotImplementedError(type(ary)) + + +def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): """ Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the placeholder version of an argument to @@ -165,10 +199,16 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name): return pt.make_placeholder(name, (), np.dtype(type(arg))) elif isinstance(arg, pt.Array): name = arg_id_to_name[(kw,)] - return pt.make_placeholder(name, arg.shape, arg.dtype) + # Transform the DAG to give metadata inference a chance to do its job + arg = _to_input_for_compiled(arg, actx) + return pt.make_placeholder(name, arg.shape, arg.dtype, + axes=arg.axes, + tags=arg.tags) elif is_array_container_type(arg.__class__): def _rec_to_placeholder(keys, ary): name = arg_id_to_name[(kw,) + keys] + # Transform the DAG to give metadata inference a chance to do its job + ary = _to_input_for_compiled(ary, actx) return pt.make_placeholder(name, ary.shape, ary.dtype, @@ -293,9 +333,12 @@ class LazilyCompilingFunctionCaller: for arg_id in arg_id_to_arg} output_template = self.f( - *[_get_f_placeholder_args(arg, iarg, input_id_to_name_in_program) + *[_get_f_placeholder_args(arg, iarg, + input_id_to_name_in_program, self.actx) for iarg, arg in enumerate(args)], - **{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program) + **{kw: _get_f_placeholder_args(arg, kw, + input_id_to_name_in_program, + self.actx) for kw, arg in kwargs.items()}) if (not (is_array_container_type(output_template.__class__) -- GitLab