From d23600f0882b659ea28cbe0826591de7bf08e3a6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 29 Jun 2021 16:27:57 -0500 Subject: [PATCH] avoid stateful update of tranforming a loopy t-unit inside a pt-program --- arraycontext/impl/pytato/__init__.py | 7 ++++--- arraycontext/impl/pytato/compile.py | 8 ++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 23bf708..b6f035e 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -122,9 +122,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext): raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " f"non-pytato array of type '{type(array)}'") - t_unit = pt.generate_loopy(array, cl_device=self.queue.device) - t_unit = self.transform_loopy_program(t_unit) - evt, (cl_array,) = t_unit(self.queue) + pt_prg = pt.generate_loopy(array, cl_device=self.queue.device) + pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) + + evt, (cl_array,) = pt_prg(self.queue) evt.wait() return cl_array.with_queue(None) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 27be8cd..4f608e0 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -178,6 +178,7 @@ class LazilyCompilingFunctionCaller: :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ + from pytato.target.loopy import BoundPyOpenCLProgram arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args) try: @@ -221,9 +222,12 @@ class LazilyCompilingFunctionCaller: pytato_program = pt.generate_loopy(dict_of_named_arrays, options={"return_dict": True}, cl_device=self.actx.queue.device) + assert isinstance(pytato_program, BoundPyOpenCLProgram) - pytato_program.program = self.actx.transform_loopy_program(pytato_program - .program) + pytato_program = (pytato_program + .with_transformed_program(self + .actx + .transform_loopy_program)) self.program_cache[arg_id_to_descr] = CompiledFunction( self.actx, pytato_program, -- GitLab