From f50736c182be71f4e673bb80fe429ac24d86f775 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Wed, 28 Jul 2021 20:22:03 -0500 Subject: [PATCH] defines PytatoPyOpenCLArrayContext.transform_dag Co-authored-by: Andreas Kloeckner <andreask@illinois.edu> --- arraycontext/impl/pytato/__init__.py | 37 ++++++++++++++++++++++++---- arraycontext/impl/pytato/compile.py | 6 ++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index b4ac63e..beaebc4 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -43,9 +43,12 @@ THE SOFTWARE. from arraycontext.context import ArrayContext import numpy as np -from typing import Any, Callable, Union, Sequence +from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag +if TYPE_CHECKING: + import pytato + class PytatoPyOpenCLArrayContext(ArrayContext): """ @@ -62,6 +65,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): to use the default allocator. .. automethod:: __init__ + + .. automethod:: transform_dag """ def __init__(self, queue, allocator=None): @@ -116,6 +121,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def freeze(self, array): import pytato as pt import pyopencl.array as cla + import loopy as lp if isinstance(array, cla.Array): return array.with_queue(None) @@ -134,20 +140,27 @@ class PytatoPyOpenCLArrayContext(ArrayContext): # }}} from arraycontext.impl.pytato.utils import _normalize_pt_expr - normalized_expr, bound_arguments = _normalize_pt_expr(array) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( + {"_actx_out": array}) + + normalized_expr, bound_arguments = _normalize_pt_expr( + pt_dict_of_named_arrays) try: pt_prg = self._freeze_prg_cache[normalized_expr] except KeyError: - pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device) + pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr), + options=lp.Options(return_dict=True, + no_numpy=True), + cl_device=self.queue.device) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg assert len(pt_prg.bound_arguments) == 0 - evt, (cl_array,) = pt_prg(self.queue, **bound_arguments) + evt, out_dict = pt_prg(self.queue, **bound_arguments) evt.wait() - return cl_array.with_queue(None) + return out_dict["_actx_out"].with_queue(None) def thaw(self, array): import pytato as pt @@ -170,6 +183,20 @@ class PytatoPyOpenCLArrayContext(ArrayContext): "transform_loopy_program. Sub-classes are supposed " "to implement it.") + def transform_dag(self, dag: "pytato.DictOfNamedArrays" + ) -> "pytato.DictOfNamedArrays": + """ + Returns a transformed version of *dag*. Sub-classes are supposed to + override this method to implement context-specific transformations on + *dag* (most likely to perform domain-specific optimizations). Every + :mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is + passed through this routine. + + :arg dag: An instance of :class:`pytato.DictOfNamedArrays` + :returns: A transformed version of *dag*. + """ + return dag + def tag(self, tags: Union[Sequence[Tag], Tag], array): return array.tagged(tags) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 5fad96c..faed2cf 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -226,7 +226,11 @@ class LazilyCompilingFunctionCaller: outputs) import loopy as lp - pytato_program = pt.generate_loopy(dict_of_named_arrays, + + pt_dict_of_named_arrays = self.actx.transform_dag( + pt.make_dict_of_named_arrays(dict_of_named_arrays)) + + pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, options=lp.Options( return_dict=True, no_numpy=True), -- GitLab