diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index cf8f3eb9b5becda8a0893cb8564373b9cf0483a6..a7e74bd3a1628897b2f86c59a968313cfbd5572f 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -2,6 +2,7 @@ .. currentmodule:: arraycontext.impl.pytato.compile .. autoclass:: LazilyCompilingFunctionCaller .. autoclass:: CompiledFunction +.. autoclass:: FromArrayContextCompile """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -40,6 +41,7 @@ from pyrsistent import pmap, PMap import pyopencl.array as cla import pytato as pt import itertools +from pytools.tag import Tag from pytools import ProcessLogger @@ -47,6 +49,17 @@ import logging logger = logging.getLogger(__name__) +class FromArrayContextCompile(Tag): + """ + Tagged to the entrypoint kernel of every translation unit that is generated + by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. + + Typically this tag serves as a branch condition in implementing a + specialized transform strategy for kernels compiled by + :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. + """ + + # {{{ helper classes: AbstractInputDescriptor class AbstractInputDescriptor: @@ -245,6 +258,13 @@ class LazilyCompilingFunctionCaller: assert isinstance(pytato_program, BoundPyOpenCLProgram) with ProcessLogger(logger, "transform_loopy_program"): + + pytato_program = (pytato_program + .with_transformed_program( + lambda x: x.with_kernel( + x.default_entrypoint + .tagged(FromArrayContextCompile())))) + pytato_program = (pytato_program .with_transformed_program(self .actx