From f50736c182be71f4e673bb80fe429ac24d86f775 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Wed, 28 Jul 2021 20:22:03 -0500
Subject: [PATCH] defines PytatoPyOpenCLArrayContext.transform_dag

Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>
---
 arraycontext/impl/pytato/__init__.py | 37 ++++++++++++++++++++++++----
 arraycontext/impl/pytato/compile.py  |  6 ++++-
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index b4ac63e..beaebc4 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -43,9 +43,12 @@ THE SOFTWARE.
 
 from arraycontext.context import ArrayContext
 import numpy as np
-from typing import Any, Callable, Union, Sequence
+from typing import Any, Callable, Union, Sequence, TYPE_CHECKING
 from pytools.tag import Tag
 
+if TYPE_CHECKING:
+    import pytato
+
 
 class PytatoPyOpenCLArrayContext(ArrayContext):
     """
@@ -62,6 +65,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         to use the default allocator.
 
     .. automethod:: __init__
+
+    .. automethod:: transform_dag
     """
 
     def __init__(self, queue, allocator=None):
@@ -116,6 +121,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     def freeze(self, array):
         import pytato as pt
         import pyopencl.array as cla
+        import loopy as lp
 
         if isinstance(array, cla.Array):
             return array.with_queue(None)
@@ -134,20 +140,27 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         # }}}
 
         from arraycontext.impl.pytato.utils import _normalize_pt_expr
-        normalized_expr, bound_arguments = _normalize_pt_expr(array)
+        pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
+                {"_actx_out": array})
+
+        normalized_expr, bound_arguments = _normalize_pt_expr(
+                pt_dict_of_named_arrays)
 
         try:
             pt_prg = self._freeze_prg_cache[normalized_expr]
         except KeyError:
-            pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device)
+            pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr),
+                                       options=lp.Options(return_dict=True,
+                                                          no_numpy=True),
+                                       cl_device=self.queue.device)
             pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
             self._freeze_prg_cache[normalized_expr] = pt_prg
 
         assert len(pt_prg.bound_arguments) == 0
-        evt, (cl_array,) = pt_prg(self.queue, **bound_arguments)
+        evt, out_dict = pt_prg(self.queue, **bound_arguments)
         evt.wait()
 
-        return cl_array.with_queue(None)
+        return out_dict["_actx_out"].with_queue(None)
 
     def thaw(self, array):
         import pytato as pt
@@ -170,6 +183,20 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
                          "transform_loopy_program. Sub-classes are supposed "
                          "to implement it.")
 
+    def transform_dag(self, dag: "pytato.DictOfNamedArrays"
+                      ) -> "pytato.DictOfNamedArrays":
+        """
+        Returns a transformed version of *dag*. Sub-classes are supposed to
+        override this method to implement context-specific transformations on
+        *dag* (most likely to perform domain-specific optimizations). Every
+        :mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is
+        passed through this routine.
+
+        :arg dag: An instance of :class:`pytato.DictOfNamedArrays`
+        :returns: A transformed version of *dag*.
+        """
+        return dag
+
     def tag(self, tags: Union[Sequence[Tag], Tag], array):
         return array.tagged(tags)
 
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 5fad96c..faed2cf 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -226,7 +226,11 @@ class LazilyCompilingFunctionCaller:
                                       outputs)
 
         import loopy as lp
-        pytato_program = pt.generate_loopy(dict_of_named_arrays,
+
+        pt_dict_of_named_arrays = self.actx.transform_dag(
+            pt.make_dict_of_named_arrays(dict_of_named_arrays))
+
+        pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
                                            options=lp.Options(
                                                return_dict=True,
                                                no_numpy=True),
-- 
GitLab