diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 3e3339f40d4681d3ef52c47a94bf4b53267ebb9c..480cb1c58f84c537046e337eafc34684fd751be4 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -101,6 +101,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return call_loopy(program, kwargs, entrypoint) def freeze(self, array): + # TODO: This should store a cache of pytato DAG -> build pyopencl + # program instead of re-compiling the DAG for every freeze. + import pytato as pt import pyopencl.array as cla @@ -110,8 +113,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " f"non-pytato array of type '{type(array)}'") - prg = pt.generate_loopy(array, cl_device=self.queue.device) - evt, (cl_array,) = prg(self.queue) + t_unit = pt.generate_loopy(array, cl_device=self.queue.device) + t_unit = self.transform_loopy_program(t_unit) + evt, (cl_array,) = t_unit(self.queue) evt.wait() return cl_array.with_queue(None) @@ -132,7 +136,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller return LazilyCompilingFunctionCaller(self, f) - def transform_loopy_program(self, prg): + def transform_loopy_program(self, t_unit): from warnings import warn warn( "Using arraycontext.PytatoPyOpenCLArrayContext.transform_loopy_program " @@ -144,56 +148,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): "to build on.", DeprecationWarning, stacklevel=2) - from loopy.translation_unit import for_each_kernel - - nwg = 48 - nwi = (16, 2) - - @for_each_kernel - def gridify(knl): - # {{{ Pattern matching inames - - for insn in knl.instructions: - if isinstance(insn, lp.CallInstruction): - # must be a callable kernel, don't touch. - pass - elif isinstance(insn, lp.Assignment): - bigger_loop = None - smaller_loop = None - for iname in insn.within_inames: - if iname.startswith("iel"): - assert bigger_loop is None - bigger_loop = iname - if iname.startswith("idof"): - assert smaller_loop is None - smaller_loop = iname - - if bigger_loop or smaller_loop: - assert bigger_loop is not None and smaller_loop is not None - else: - sorted_inames = sorted(tuple(insn.within_inames), - key=knl.get_constant_iname_length) - smaller_loop = sorted_inames[0] - bigger_loop = sorted_inames[1] - - knl = lp.chunk_iname(knl, bigger_loop, nwg, - outer_tag="g.0") - knl = lp.split_iname(knl, f"{bigger_loop}_inner", - nwi[0], inner_tag="l.1") - knl = lp.split_iname(knl, smaller_loop, - nwi[1], inner_tag="l.0") - elif isinstance(insn, lp.BarrierInstruction): - pass - else: - raise NotImplementedError - - # }}} - - return knl - - prg = lp.set_options(prg, "insert_additional_gbarriers") - - return gridify(prg) + return t_unit def tag(self, tags: Union[Sequence[Tag], Tag], array): return array.tagged(tags) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index f169ea7674b13d5dc86d4e20802a3312b6a66e18..31baed1a13c6ef03c18751245b8b14960b8a2c7d 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -221,6 +221,9 @@ class LazilyCompilingFunctionCaller: options={"return_dict": True}, cl_device=self.actx.queue.device) + pytato_program.program = self.actx.transform_loopy_program(pytato_program + .program) + self.program_cache[arg_id_to_descr] = CompiledFunction( self.actx, pytato_program, input_naming_map, output_naming_map,