Skip to content
Snippets Groups Projects
Commit f50736c1 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

defines PytatoPyOpenCLArrayContext.transform_dag

parent 78dc9221
No related branches found
No related tags found
No related merge requests found
......@@ -43,9 +43,12 @@ THE SOFTWARE.
from arraycontext.context import ArrayContext
import numpy as np
from typing import Any, Callable, Union, Sequence
from typing import Any, Callable, Union, Sequence, TYPE_CHECKING
from pytools.tag import Tag
if TYPE_CHECKING:
import pytato
class PytatoPyOpenCLArrayContext(ArrayContext):
"""
......@@ -62,6 +65,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
to use the default allocator.
.. automethod:: __init__
.. automethod:: transform_dag
"""
def __init__(self, queue, allocator=None):
......@@ -116,6 +121,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
def freeze(self, array):
import pytato as pt
import pyopencl.array as cla
import loopy as lp
if isinstance(array, cla.Array):
return array.with_queue(None)
......@@ -134,20 +140,27 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
# }}}
from arraycontext.impl.pytato.utils import _normalize_pt_expr
normalized_expr, bound_arguments = _normalize_pt_expr(array)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
{"_actx_out": array})
normalized_expr, bound_arguments = _normalize_pt_expr(
pt_dict_of_named_arrays)
try:
pt_prg = self._freeze_prg_cache[normalized_expr]
except KeyError:
pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device)
pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr),
options=lp.Options(return_dict=True,
no_numpy=True),
cl_device=self.queue.device)
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg
assert len(pt_prg.bound_arguments) == 0
evt, (cl_array,) = pt_prg(self.queue, **bound_arguments)
evt, out_dict = pt_prg(self.queue, **bound_arguments)
evt.wait()
return cl_array.with_queue(None)
return out_dict["_actx_out"].with_queue(None)
def thaw(self, array):
import pytato as pt
......@@ -170,6 +183,20 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
"transform_loopy_program. Sub-classes are supposed "
"to implement it.")
def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
*dag* (most likely to perform domain-specific optimizations). Every
:mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is
passed through this routine.
:arg dag: An instance of :class:`pytato.DictOfNamedArrays`
:returns: A transformed version of *dag*.
"""
return dag
def tag(self, tags: Union[Sequence[Tag], Tag], array):
return array.tagged(tags)
......
......@@ -226,7 +226,11 @@ class LazilyCompilingFunctionCaller:
outputs)
import loopy as lp
pytato_program = pt.generate_loopy(dict_of_named_arrays,
pt_dict_of_named_arrays = self.actx.transform_dag(
pt.make_dict_of_named_arrays(dict_of_named_arrays))
pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
options=lp.Options(
return_dict=True,
no_numpy=True),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment