From 13eed85257d003a425719f9bd9fa2a57dd5c2354 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 25 Aug 2021 17:33:31 -0500 Subject: [PATCH] introduces FromArrayContextCompile tag --- arraycontext/impl/pytato/compile.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index cf8f3eb..a7e74bd 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -2,6 +2,7 @@ .. currentmodule:: arraycontext.impl.pytato.compile .. autoclass:: LazilyCompilingFunctionCaller .. autoclass:: CompiledFunction +.. autoclass:: FromArrayContextCompile """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -40,6 +41,7 @@ from pyrsistent import pmap, PMap import pyopencl.array as cla import pytato as pt import itertools +from pytools.tag import Tag from pytools import ProcessLogger @@ -47,6 +49,17 @@ import logging logger = logging.getLogger(__name__) +class FromArrayContextCompile(Tag): + """ + Tagged to the entrypoint kernel of every translation unit that is generated + by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. + + Typically this tag serves as a branch condition in implementing a + specialized transform strategy for kernels compiled by + :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. + """ + + # {{{ helper classes: AbstractInputDescriptor class AbstractInputDescriptor: @@ -245,6 +258,13 @@ class LazilyCompilingFunctionCaller: assert isinstance(pytato_program, BoundPyOpenCLProgram) with ProcessLogger(logger, "transform_loopy_program"): + + pytato_program = (pytato_program + .with_transformed_program( + lambda x: x.with_kernel( + x.default_entrypoint + .tagged(FromArrayContextCompile())))) + pytato_program = (pytato_program .with_transformed_program(self .actx -- GitLab