Skip to content
Snippets Groups Projects
Commit d23600f0 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

avoid stateful update of tranforming a loopy t-unit inside a pt-program

parent 2eb6c64a
No related branches found
No related tags found
No related merge requests found
...@@ -122,9 +122,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -122,9 +122,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
f"non-pytato array of type '{type(array)}'") f"non-pytato array of type '{type(array)}'")
t_unit = pt.generate_loopy(array, cl_device=self.queue.device) pt_prg = pt.generate_loopy(array, cl_device=self.queue.device)
t_unit = self.transform_loopy_program(t_unit) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
evt, (cl_array,) = t_unit(self.queue)
evt, (cl_array,) = pt_prg(self.queue)
evt.wait() evt.wait()
return cl_array.with_queue(None) return cl_array.with_queue(None)
......
...@@ -178,6 +178,7 @@ class LazilyCompilingFunctionCaller: ...@@ -178,6 +178,7 @@ class LazilyCompilingFunctionCaller:
:attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
The intermediary pytato DAG for *args* is memoized in *self*. The intermediary pytato DAG for *args* is memoized in *self*.
""" """
from pytato.target.loopy import BoundPyOpenCLProgram
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args) arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(args)
try: try:
...@@ -221,9 +222,12 @@ class LazilyCompilingFunctionCaller: ...@@ -221,9 +222,12 @@ class LazilyCompilingFunctionCaller:
pytato_program = pt.generate_loopy(dict_of_named_arrays, pytato_program = pt.generate_loopy(dict_of_named_arrays,
options={"return_dict": True}, options={"return_dict": True},
cl_device=self.actx.queue.device) cl_device=self.actx.queue.device)
assert isinstance(pytato_program, BoundPyOpenCLProgram)
pytato_program.program = self.actx.transform_loopy_program(pytato_program pytato_program = (pytato_program
.program) .with_transformed_program(self
.actx
.transform_loopy_program))
self.program_cache[arg_id_to_descr] = CompiledFunction( self.program_cache[arg_id_to_descr] = CompiledFunction(
self.actx, pytato_program, self.actx, pytato_program,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment