Skip to content
Snippets Groups Projects
Commit d23600f0 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

avoid stateful update of tranforming a loopy t-unit inside a pt-program

parent 2eb6c64a
No related branches found
No related tags found
No related merge requests found
......@@ -122,9 +122,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
f"non-pytato array of type '{type(array)}'")
t_unit = pt.generate_loopy(array, cl_device=self.queue.device)
t_unit = self.transform_loopy_program(t_unit)
evt, (cl_array,) = t_unit(self.queue)
pt_prg = pt.generate_loopy(array, cl_device=self.queue.device)
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
evt, (cl_array,) = pt_prg(self.queue)
evt.wait()
return cl_array.with_queue(None)
......
......@@ -178,6 +178,7 @@ class LazilyCompilingFunctionCaller:
:attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
The intermediary pytato DAG for *args* is memoized in *self*.
"""
from pytato.target.loopy import BoundPyOpenCLProgram
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args)
try:
......@@ -221,9 +222,12 @@ class LazilyCompilingFunctionCaller:
pytato_program = pt.generate_loopy(dict_of_named_arrays,
options={"return_dict": True},
cl_device=self.actx.queue.device)
assert isinstance(pytato_program, BoundPyOpenCLProgram)
pytato_program.program = self.actx.transform_loopy_program(pytato_program
.program)
pytato_program = (pytato_program
.with_transformed_program(self
.actx
.transform_loopy_program))
self.program_cache[arg_id_to_descr] = CompiledFunction(
self.actx, pytato_program,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment