From d23600f0882b659ea28cbe0826591de7bf08e3a6 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Tue, 29 Jun 2021 16:27:57 -0500
Subject: [PATCH] avoid stateful update of tranforming a loopy t-unit inside a
 pt-program

---
 arraycontext/impl/pytato/__init__.py | 7 ++++---
 arraycontext/impl/pytato/compile.py  | 8 ++++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 23bf708..b6f035e 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -122,9 +122,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
             raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
                             f"non-pytato array of type '{type(array)}'")
 
-        t_unit = pt.generate_loopy(array, cl_device=self.queue.device)
-        t_unit = self.transform_loopy_program(t_unit)
-        evt, (cl_array,) = t_unit(self.queue)
+        pt_prg = pt.generate_loopy(array, cl_device=self.queue.device)
+        pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
+
+        evt, (cl_array,) = pt_prg(self.queue)
         evt.wait()
 
         return cl_array.with_queue(None)
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 27be8cd..4f608e0 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -178,6 +178,7 @@ class LazilyCompilingFunctionCaller:
         :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
         The intermediary pytato DAG for *args* is memoized in *self*.
         """
+        from pytato.target.loopy import BoundPyOpenCLProgram
         arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args)
 
         try:
@@ -221,9 +222,12 @@ class LazilyCompilingFunctionCaller:
         pytato_program = pt.generate_loopy(dict_of_named_arrays,
                                            options={"return_dict": True},
                                            cl_device=self.actx.queue.device)
+        assert isinstance(pytato_program, BoundPyOpenCLProgram)
 
-        pytato_program.program = self.actx.transform_loopy_program(pytato_program
-                                                                   .program)
+        pytato_program = (pytato_program
+                          .with_transformed_program(self
+                                                    .actx
+                                                    .transform_loopy_program))
 
         self.program_cache[arg_id_to_descr] = CompiledFunction(
                                                 self.actx, pytato_program,
-- 
GitLab