diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 23bf7087226ce9caad67d151d982bf523ef011b5..b6f035e4132a79d06276f5c16cc04fdfee04715c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -122,9 +122,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext): raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " f"non-pytato array of type '{type(array)}'") - t_unit = pt.generate_loopy(array, cl_device=self.queue.device) - t_unit = self.transform_loopy_program(t_unit) - evt, (cl_array,) = t_unit(self.queue) + pt_prg = pt.generate_loopy(array, cl_device=self.queue.device) + pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) + + evt, (cl_array,) = pt_prg(self.queue) evt.wait() return cl_array.with_queue(None) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 27be8cd9691f9ece3247039dd0ef7dbe74ea62dd..4f608e09243484e8a7d8dd26255e7b892c21d853 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -178,6 +178,7 @@ class LazilyCompilingFunctionCaller: :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ + from pytato.target.loopy import BoundPyOpenCLProgram arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args) try: @@ -221,9 +222,12 @@ class LazilyCompilingFunctionCaller: pytato_program = pt.generate_loopy(dict_of_named_arrays, options={"return_dict": True}, cl_device=self.actx.queue.device) + assert isinstance(pytato_program, BoundPyOpenCLProgram) - pytato_program.program = self.actx.transform_loopy_program(pytato_program - .program) + pytato_program = (pytato_program + .with_transformed_program(self + .actx + .transform_loopy_program)) self.program_cache[arg_id_to_descr] = CompiledFunction( self.actx, pytato_program,