Skip to content
Snippets Groups Projects
Commit a567041a authored by Kaushik Kulkarni
Browse files

make PytatoPyOpenCLArrayContext.transform_loopy_program an identity map

parent 32967894
No related branches found
No related tags found
No related merge requests found
......@@ -101,6 +101,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
return call_loopy(program, kwargs, entrypoint)
def freeze(self, array):
# TODO: This should store a cache of pytato DAG -> built pyopencl
# program instead of re-compiling the DAG for every freeze.
import pytato as pt
import pyopencl.array as cla
......@@ -110,8 +113,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
f"non-pytato array of type '{type(array)}'")
prg = pt.generate_loopy(array, cl_device=self.queue.device)
evt, (cl_array,) = prg(self.queue)
t_unit = pt.generate_loopy(array, cl_device=self.queue.device)
t_unit = self.transform_loopy_program(t_unit)
evt, (cl_array,) = t_unit(self.queue)
evt.wait()
return cl_array.with_queue(None)
......@@ -132,7 +136,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller
return LazilyCompilingFunctionCaller(self, f)
def transform_loopy_program(self, prg):
def transform_loopy_program(self, t_unit):
from warnings import warn
warn(
"Using arraycontext.PytatoPyOpenCLArrayContext.transform_loopy_program "
......@@ -144,56 +148,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
"to build on.",
DeprecationWarning, stacklevel=2)
from loopy.translation_unit import for_each_kernel
nwg = 48
nwi = (16, 2)
@for_each_kernel
def gridify(knl):
# {{{ Pattern matching inames
for insn in knl.instructions:
if isinstance(insn, lp.CallInstruction):
# must be a callable kernel, don't touch.
pass
elif isinstance(insn, lp.Assignment):
bigger_loop = None
smaller_loop = None
for iname in insn.within_inames:
if iname.startswith("iel"):
assert bigger_loop is None
bigger_loop = iname
if iname.startswith("idof"):
assert smaller_loop is None
smaller_loop = iname
if bigger_loop or smaller_loop:
assert bigger_loop is not None and smaller_loop is not None
else:
sorted_inames = sorted(tuple(insn.within_inames),
key=knl.get_constant_iname_length)
smaller_loop = sorted_inames[0]
bigger_loop = sorted_inames[1]
knl = lp.chunk_iname(knl, bigger_loop, nwg,
outer_tag="g.0")
knl = lp.split_iname(knl, f"{bigger_loop}_inner",
nwi[0], inner_tag="l.1")
knl = lp.split_iname(knl, smaller_loop,
nwi[1], inner_tag="l.0")
elif isinstance(insn, lp.BarrierInstruction):
pass
else:
raise NotImplementedError
# }}}
return knl
prg = lp.set_options(prg, "insert_additional_gbarriers")
return gridify(prg)
return t_unit
def tag(self, tags: Union[Sequence[Tag], Tag], array):
    """Return the result of ``array.tagged(tags)``.

    :arg tags: a single tag or a sequence of tags, forwarded unchanged.
    :arg array: an object providing a ``tagged`` method.
    """
    tagged_array = array.tagged(tags)
    return tagged_array
......
......@@ -221,6 +221,9 @@ class LazilyCompilingFunctionCaller:
options={"return_dict": True},
cl_device=self.actx.queue.device)
pytato_program.program = self.actx.transform_loopy_program(pytato_program
.program)
self.program_cache[arg_id_to_descr] = CompiledFunction(
self.actx, pytato_program,
input_naming_map, output_naming_map,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment