From a567041ab9fb4b394178985503060e79494ffe18 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 28 Jun 2021 12:03:11 -0500
Subject: [PATCH] make PytatoPyOpenCLArrayContext.transform_loopy_program an
 identity map

---
 arraycontext/impl/pytato/__init__.py | 61 ++++------------------------
 arraycontext/impl/pytato/compile.py  |  3 ++
 2 files changed, 11 insertions(+), 53 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 3e3339f..480cb1c 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -101,6 +101,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         return call_loopy(program, kwargs, entrypoint)
 
     def freeze(self, array):
+        # TODO: This should store a cache of pytato DAG -> build pyopencl
+        # program instead of re-compiling the DAG for every freeze.
+
         import pytato as pt
         import pyopencl.array as cla
 
@@ -110,8 +113,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
             raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
                             f"non-pytato array of type '{type(array)}'")
 
-        prg = pt.generate_loopy(array, cl_device=self.queue.device)
-        evt, (cl_array,) = prg(self.queue)
+        t_unit = pt.generate_loopy(array, cl_device=self.queue.device)
+        t_unit = self.transform_loopy_program(t_unit)
+        evt, (cl_array,) = t_unit(self.queue)
         evt.wait()
 
         return cl_array.with_queue(None)
@@ -132,7 +136,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller
         return LazilyCompilingFunctionCaller(self, f)
 
-    def transform_loopy_program(self, prg):
+    def transform_loopy_program(self, t_unit):
         from warnings import warn
         warn(
             "Using arraycontext.PytatoPyOpenCLArrayContext.transform_loopy_program "
@@ -144,56 +148,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
             "to build on.",
             DeprecationWarning, stacklevel=2)
 
-        from loopy.translation_unit import for_each_kernel
-
-        nwg = 48
-        nwi = (16, 2)
-
-        @for_each_kernel
-        def gridify(knl):
-            # {{{ Pattern matching inames
-
-            for insn in knl.instructions:
-                if isinstance(insn, lp.CallInstruction):
-                    # must be a callable kernel, don't touch.
-                    pass
-                elif isinstance(insn, lp.Assignment):
-                    bigger_loop = None
-                    smaller_loop = None
-                    for iname in insn.within_inames:
-                        if iname.startswith("iel"):
-                            assert bigger_loop is None
-                            bigger_loop = iname
-                        if iname.startswith("idof"):
-                            assert smaller_loop is None
-                            smaller_loop = iname
-
-                    if bigger_loop or smaller_loop:
-                        assert bigger_loop is not None and smaller_loop is not None
-                    else:
-                        sorted_inames = sorted(tuple(insn.within_inames),
-                                key=knl.get_constant_iname_length)
-                        smaller_loop = sorted_inames[0]
-                        bigger_loop = sorted_inames[1]
-
-                    knl = lp.chunk_iname(knl, bigger_loop, nwg,
-                            outer_tag="g.0")
-                    knl = lp.split_iname(knl, f"{bigger_loop}_inner",
-                            nwi[0], inner_tag="l.1")
-                    knl = lp.split_iname(knl, smaller_loop,
-                            nwi[1], inner_tag="l.0")
-                elif isinstance(insn, lp.BarrierInstruction):
-                    pass
-                else:
-                    raise NotImplementedError
-
-            # }}}
-
-            return knl
-
-        prg = lp.set_options(prg, "insert_additional_gbarriers")
-
-        return gridify(prg)
+        return t_unit
 
     def tag(self, tags: Union[Sequence[Tag], Tag], array):
         return array.tagged(tags)
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index f169ea7..31baed1 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -221,6 +221,9 @@ class LazilyCompilingFunctionCaller:
                                            options={"return_dict": True},
                                            cl_device=self.actx.queue.device)
 
+        pytato_program.program = self.actx.transform_loopy_program(pytato_program
+                                                                   .program)
+
         self.program_cache[arg_id_to_descr] = CompiledFunction(
                                                 self.actx, pytato_program,
                                                 input_naming_map, output_naming_map,
-- 
GitLab