From a567041ab9fb4b394178985503060e79494ffe18 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 28 Jun 2021 12:03:11 -0500 Subject: [PATCH] make PytatoPyOpenCLArrayContext.transform_loopy_program an identity map --- arraycontext/impl/pytato/__init__.py | 61 ++++------------------------ arraycontext/impl/pytato/compile.py | 3 ++ 2 files changed, 11 insertions(+), 53 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 3e3339f..480cb1c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -101,6 +101,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return call_loopy(program, kwargs, entrypoint) def freeze(self, array): + # TODO: This should store a cache of pytato DAG -> build pyopencl + # program instead of re-compiling the DAG for every freeze. + import pytato as pt import pyopencl.array as cla @@ -110,8 +113,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " f"non-pytato array of type '{type(array)}'") - prg = pt.generate_loopy(array, cl_device=self.queue.device) - evt, (cl_array,) = prg(self.queue) + t_unit = pt.generate_loopy(array, cl_device=self.queue.device) + t_unit = self.transform_loopy_program(t_unit) + evt, (cl_array,) = t_unit(self.queue) evt.wait() return cl_array.with_queue(None) @@ -132,7 +136,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller return LazilyCompilingFunctionCaller(self, f) - def transform_loopy_program(self, prg): + def transform_loopy_program(self, t_unit): from warnings import warn warn( "Using arraycontext.PytatoPyOpenCLArrayContext.transform_loopy_program " @@ -144,56 +148,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): "to build on.", DeprecationWarning, stacklevel=2) - from loopy.translation_unit import for_each_kernel - - nwg = 48 - nwi = (16, 2) - - @for_each_kernel - def gridify(knl): - # {{{ Pattern matching inames - - for insn in knl.instructions: - if isinstance(insn, lp.CallInstruction): - # must be a callable kernel, don't touch. - pass - elif isinstance(insn, lp.Assignment): - bigger_loop = None - smaller_loop = None - for iname in insn.within_inames: - if iname.startswith("iel"): - assert bigger_loop is None - bigger_loop = iname - if iname.startswith("idof"): - assert smaller_loop is None - smaller_loop = iname - - if bigger_loop or smaller_loop: - assert bigger_loop is not None and smaller_loop is not None - else: - sorted_inames = sorted(tuple(insn.within_inames), - key=knl.get_constant_iname_length) - smaller_loop = sorted_inames[0] - bigger_loop = sorted_inames[1] - - knl = lp.chunk_iname(knl, bigger_loop, nwg, - outer_tag="g.0") - knl = lp.split_iname(knl, f"{bigger_loop}_inner", - nwi[0], inner_tag="l.1") - knl = lp.split_iname(knl, smaller_loop, - nwi[1], inner_tag="l.0") - elif isinstance(insn, lp.BarrierInstruction): - pass - else: - raise NotImplementedError - - # }}} - - return knl - - prg = lp.set_options(prg, "insert_additional_gbarriers") - - return gridify(prg) + return t_unit def tag(self, tags: Union[Sequence[Tag], Tag], array): return array.tagged(tags) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index f169ea7..31baed1 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -221,6 +221,9 @@ class LazilyCompilingFunctionCaller: options={"return_dict": True}, cl_device=self.actx.queue.device) + pytato_program.program = self.actx.transform_loopy_program(pytato_program + .program) + self.program_cache[arg_id_to_descr] = CompiledFunction( self.actx, pytato_program, input_naming_map, output_naming_map, -- GitLab