Skip to content
Snippets Groups Projects
Commit 13eed852 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

introduces FromArrayContextCompile tag

parent 6eaf5186
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
.. currentmodule:: arraycontext.impl.pytato.compile .. currentmodule:: arraycontext.impl.pytato.compile
.. autoclass:: LazilyCompilingFunctionCaller .. autoclass:: LazilyCompilingFunctionCaller
.. autoclass:: CompiledFunction .. autoclass:: CompiledFunction
.. autoclass:: FromArrayContextCompile
""" """
__copyright__ = """ __copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees Copyright (C) 2020-1 University of Illinois Board of Trustees
...@@ -40,6 +41,7 @@ from pyrsistent import pmap, PMap ...@@ -40,6 +41,7 @@ from pyrsistent import pmap, PMap
import pyopencl.array as cla import pyopencl.array as cla
import pytato as pt import pytato as pt
import itertools import itertools
from pytools.tag import Tag
from pytools import ProcessLogger from pytools import ProcessLogger
...@@ -47,6 +49,17 @@ import logging ...@@ -47,6 +49,17 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FromArrayContextCompile(Tag):
"""
Tagged to the entrypoint kernel of every translation unit that is generated
by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.
Typically this tag serves as a branch condition in implementing a
specialized transform strategy for kernels compiled by
:meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.
"""
# {{{ helper classes: AbstractInputDescriptor # {{{ helper classes: AbstractInputDescriptor
class AbstractInputDescriptor: class AbstractInputDescriptor:
...@@ -245,6 +258,13 @@ class LazilyCompilingFunctionCaller: ...@@ -245,6 +258,13 @@ class LazilyCompilingFunctionCaller:
assert isinstance(pytato_program, BoundPyOpenCLProgram) assert isinstance(pytato_program, BoundPyOpenCLProgram)
with ProcessLogger(logger, "transform_loopy_program"): with ProcessLogger(logger, "transform_loopy_program"):
pytato_program = (pytato_program
.with_transformed_program(
lambda x: x.with_kernel(
x.default_entrypoint
.tagged(FromArrayContextCompile()))))
pytato_program = (pytato_program pytato_program = (pytato_program
.with_transformed_program(self .with_transformed_program(self
.actx .actx
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment