Skip to content
Snippets Groups Projects
Commit 91a94fc3 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

make the frozen type of PytatoPyOpenCLArrayContext to be TaggableCLArrays

parent c8427b92
No related branches found
No related tags found
No related merge requests found
...@@ -79,6 +79,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -79,6 +79,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
self.allocator = allocator self.allocator = allocator
self.array_types = (pt.Array, ) self.array_types = (pt.Array, )
self._freeze_prg_cache = {} self._freeze_prg_cache = {}
self._dag_transform_cache = {}
# unused, but necessary to keep the context alive # unused, but necessary to keep the context alive
self.context = self.queue.context self.context = self.queue.context
...@@ -113,24 +114,56 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -113,24 +114,56 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
return cl_array.get(queue=self.queue) return cl_array.get(queue=self.queue)
def call_loopy(self, program, **kwargs): def call_loopy(self, program, **kwargs):
import pyopencl.array as cla from pytato.scalar_expr import SCALAR_CLASSES
from pytato.loopy import call_loopy from pytato.loopy import call_loopy
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
entrypoint = program.default_entrypoint.name entrypoint = program.default_entrypoint.name
# thaw frozen arrays # {{{ preprocess args
kwargs = {kw: (self.thaw(arg) if isinstance(arg, cla.Array) else arg)
for kw, arg in kwargs.items()} processed_kwargs = {}
for kw, arg in sorted(kwargs.items()):
if isinstance(arg, self.array_types + SCALAR_CLASSES):
pass
elif isinstance(arg, TaggableCLArray):
arg = self.thaw(arg)
else:
raise ValueError(f"call_loopy argument '{kw}' expected to be an"
" instance of 'pytato.Array', 'Number' or"
f"'TaggableCLArray', got '{type(arg)}'")
processed_kwargs[kw] = arg
# }}}
return call_loopy(program, kwargs, entrypoint) return call_loopy(program, processed_kwargs, entrypoint)
def freeze(self, array): def freeze(self, array):
import pytato as pt import pytato as pt
import pyopencl.array as cla import pyopencl.array as cla
import loopy as lp import loopy as lp
from arraycontext.impl.pytato.utils import (_normalize_pt_expr,
get_cl_axes_from_pt_axes)
from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array,
TaggableCLArray)
if isinstance(array, cla.Array): if isinstance(array, TaggableCLArray):
return array.with_queue(None) return array.with_queue(None)
if isinstance(array, cla.Array):
from warnings import warn
warn("Freezing pyopencl.array.Array will be deprecated in 2023."
" Use `to_tagged_cl_array` to convert the array to"
" TaggableCLArray", DeprecationWarning, stacklevel=2)
return to_tagged_cl_array(array.with_queue(None),
axes=None,
tags=frozenset())
if isinstance(array, pt.DataWrapper):
# trivial freeze.
return to_tagged_cl_array(array.data.with_queue(None),
axes=get_cl_axes_from_pt_axes(array.axes),
tags=array.tags)
if not isinstance(array, pt.Array): if not isinstance(array, pt.Array):
raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
f"non-pytato array of type '{type(array)}'") f"non-pytato array of type '{type(array)}'")
...@@ -138,14 +171,16 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -138,14 +171,16 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
# {{{ early exit for 0-sized arrays # {{{ early exit for 0-sized arrays
if array.size == 0: if array.size == 0:
return cla.empty(self.queue.context, return to_tagged_cl_array(
shape=array.shape, cla.empty(self.queue.context,
dtype=array.dtype, shape=array.shape,
allocator=self.allocator) dtype=array.dtype,
allocator=self.allocator),
get_cl_axes_from_pt_axes(array.axes),
array.tags)
# }}} # }}}
from arraycontext.impl.pytato.utils import _normalize_pt_expr
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
{"_actx_out": array}) {"_actx_out": array})
...@@ -155,7 +190,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -155,7 +190,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
try: try:
pt_prg = self._freeze_prg_cache[normalized_expr] pt_prg = self._freeze_prg_cache[normalized_expr]
except KeyError: except KeyError:
pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr), if normalized_expr in self._dag_transform_cache:
transformed_dag = self._dag_transform_cache[normalized_expr]
else:
transformed_dag = self.transform_dag(normalized_expr)
self._dag_transform_cache[normalized_expr] = transformed_dag
pt_prg = pt.generate_loopy(transformed_dag,
options=lp.Options(return_dict=True, options=lp.Options(return_dict=True,
no_numpy=True), no_numpy=True),
cl_device=self.queue.device) cl_device=self.queue.device)
...@@ -166,17 +207,31 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -166,17 +207,31 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
evt, out_dict = pt_prg(self.queue, **bound_arguments) evt, out_dict = pt_prg(self.queue, **bound_arguments)
evt.wait() evt.wait()
return out_dict["_actx_out"].with_queue(None) return to_tagged_cl_array(
out_dict["_actx_out"].with_queue(None),
get_cl_axes_from_pt_axes(
self._dag_transform_cache[normalized_expr]["_actx_out"].expr.axes),
array.tags)
def thaw(self, array): def thaw(self, array):
import pytato as pt import pytato as pt
import pyopencl.array as cla from .utils import get_pt_axes_from_cl_axes
from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
if not isinstance(array, cla.Array): to_tagged_cl_array)
raise TypeError("PytatoPyOpenCLArrayContext.thaw expects CL arrays, got " import pyopencl.array as cl_array
f"{type(array)}")
if isinstance(array, TaggableCLArray):
return pt.make_data_wrapper(array.with_queue(self.queue)) pass
elif isinstance(array, cl_array.Array):
array = to_tagged_cl_array(array, axes=None, tags=frozenset())
else:
raise TypeError("PytatoPyOpenCLArrayContext.thaw expects "
"'TaggableCLArray' or 'cl.array.Array' got "
f"{type(array)}.")
return pt.make_data_wrapper(array.with_queue(self.queue),
axes=get_pt_axes_from_cl_axes(array.axes),
tags=array.tags)
# }}} # }}}
...@@ -219,12 +274,23 @@ class PytatoPyOpenCLArrayContext(ArrayContext): ...@@ -219,12 +274,23 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
def einsum(self, spec, *args, arg_names=None, tagged=()): def einsum(self, spec, *args, arg_names=None, tagged=()):
import pyopencl.array as cla import pyopencl.array as cla
import pytato as pt import pytato as pt
from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
to_tagged_cl_array)
if arg_names is None: if arg_names is None:
arg_names = (None,) * len(args) arg_names = (None,) * len(args)
def preprocess_arg(name, arg): def preprocess_arg(name, arg):
if isinstance(arg, cla.Array): if isinstance(arg, TaggableCLArray):
ary = self.thaw(arg) ary = self.thaw(arg)
elif isinstance(arg, cla.Array):
from warnings import warn
warn("Passing pyopencl.array.Array to einsum will be "
"deprecated in 2023."
" Use `to_tagged_cl_array` to convert the array to"
" TaggableCLArray.", DeprecationWarning, stacklevel=2)
ary = self.thaw(to_tagged_cl_array(arg,
axes=None,
tags=frozenset()))
else: else:
assert isinstance(arg, pt.Array) assert isinstance(arg, pt.Array)
ary = arg ary = arg
......
...@@ -34,7 +34,7 @@ from arraycontext.container.traversal import rec_keyed_map_array_container ...@@ -34,7 +34,7 @@ from arraycontext.container.traversal import rec_keyed_map_array_container
import abc import abc
import numpy as np import numpy as np
from typing import Any, Callable, Tuple, Dict, Mapping from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pyrsistent import pmap, PMap from pyrsistent import pmap, PMap
...@@ -169,7 +169,11 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name): ...@@ -169,7 +169,11 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
elif is_array_container_type(arg.__class__): elif is_array_container_type(arg.__class__):
def _rec_to_placeholder(keys, ary): def _rec_to_placeholder(keys, ary):
name = arg_id_to_name[(kw,) + keys] name = arg_id_to_name[(kw,) + keys]
return pt.make_placeholder(name, ary.shape, ary.dtype) return pt.make_placeholder(name,
ary.shape,
ary.dtype,
axes=ary.axes,
tags=ary.tags)
return rec_keyed_map_array_container(_rec_to_placeholder, arg) return rec_keyed_map_array_container(_rec_to_placeholder, arg)
else: else:
...@@ -204,6 +208,13 @@ class LazilyCompilingFunctionCaller: ...@@ -204,6 +208,13 @@ class LazilyCompilingFunctionCaller:
with ProcessLogger(logger, "transform_dag"): with ProcessLogger(logger, "transform_dag"):
pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays)
name_in_program_to_tags = {
name: out.tags
for name, out in pt_dict_of_named_arrays._data.items()}
name_in_program_to_axes = {
name: out.axes
for name, out in pt_dict_of_named_arrays._data.items()}
with ProcessLogger(logger, "generate_loopy"): with ProcessLogger(logger, "generate_loopy"):
pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
options=lp.Options( options=lp.Options(
...@@ -225,7 +236,7 @@ class LazilyCompilingFunctionCaller: ...@@ -225,7 +236,7 @@ class LazilyCompilingFunctionCaller:
.actx .actx
.transform_loopy_program)) .transform_loopy_program))
return pytato_program return pytato_program, name_in_program_to_tags, name_in_program_to_axes
def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
input_id_to_name_in_program, output_id_to_name_in_program, input_id_to_name_in_program, output_id_to_name_in_program,
...@@ -234,18 +245,23 @@ class LazilyCompilingFunctionCaller: ...@@ -234,18 +245,23 @@ class LazilyCompilingFunctionCaller:
output_id = "_pt_out" output_id = "_pt_out"
dict_of_named_arrays = pt.make_dict_of_named_arrays( dict_of_named_arrays = pt.make_dict_of_named_arrays(
{output_id: ary_or_dict_of_named_arrays}) {output_id: ary_or_dict_of_named_arrays})
pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays) pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
self._dag_to_transformed_loopy_prg(dict_of_named_arrays))
return CompiledFunctionReturningArray( return CompiledFunctionReturningArray(
self.actx, pytato_program, self.actx, pytato_program,
input_id_to_name_in_program=input_id_to_name_in_program, input_id_to_name_in_program=input_id_to_name_in_program,
output_name_in_program=output_id) output_tags=name_in_program_to_tags[output_id],
output_axes=name_in_program_to_axes[output_id],
output_name=output_id)
elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays):
pytato_program = self._dag_to_transformed_loopy_prg( pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
ary_or_dict_of_named_arrays) self._dag_to_transformed_loopy_prg(ary_or_dict_of_named_arrays))
return CompiledFunctionReturningArrayContainer( return CompiledFunctionReturningArrayContainer(
self.actx, pytato_program, self.actx, pytato_program,
input_id_to_name_in_program=input_id_to_name_in_program, input_id_to_name_in_program=input_id_to_name_in_program,
output_id_to_name_in_program=output_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program,
name_in_program_to_tags=name_in_program_to_tags,
name_in_program_to_axes=name_in_program_to_axes,
output_template=output_template) output_template=output_template)
else: else:
raise NotImplementedError(type(ary_or_dict_of_named_arrays)) raise NotImplementedError(type(ary_or_dict_of_named_arrays))
...@@ -312,6 +328,8 @@ class LazilyCompilingFunctionCaller: ...@@ -312,6 +328,8 @@ class LazilyCompilingFunctionCaller:
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
input_kwargs_for_loopy = {} input_kwargs_for_loopy = {}
for arg_id, arg in arg_id_to_arg.items(): for arg_id, arg in arg_id_to_arg.items():
...@@ -320,7 +338,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): ...@@ -320,7 +338,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
elif isinstance(arg, pt.array.DataWrapper): elif isinstance(arg, pt.array.DataWrapper):
# got a Datwwrapper => simply gets its data # got a Datwwrapper => simply gets its data
arg = arg.data arg = arg.data
elif isinstance(arg, cla.Array): elif isinstance(arg, TaggableCLArray):
# got a frozen array => do nothing # got a frozen array => do nothing
pass pass
elif isinstance(arg, pt.Array): elif isinstance(arg, pt.Array):
...@@ -383,9 +401,14 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): ...@@ -383,9 +401,14 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
pytato_program: pt.target.BoundProgram pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]]
output_template: ArrayContainer output_template: ArrayContainer
def __call__(self, arg_id_to_arg) -> ArrayContainer: def __call__(self, arg_id_to_arg) -> ArrayContainer:
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
from .utils import get_cl_axes_from_pt_axes
input_kwargs_for_loopy = _args_to_cl_buffers( input_kwargs_for_loopy = _args_to_cl_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg) self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
...@@ -399,7 +422,12 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction): ...@@ -399,7 +422,12 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
evt.wait() evt.wait()
def to_output_template(keys, _): def to_output_template(keys, _):
return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]]) name_in_program = self.output_id_to_name_in_program[keys]
return self.actx.thaw(to_tagged_cl_array(
out_dict[name_in_program],
axes=get_cl_axes_from_pt_axes(
self.name_in_program_to_axes[name_in_program]),
tags=self.name_in_program_to_tags[name_in_program]))
return rec_keyed_map_array_container(to_output_template, return rec_keyed_map_array_container(to_output_template,
self.output_template) self.output_template)
...@@ -415,9 +443,14 @@ class CompiledFunctionReturningArray(CompiledFunction): ...@@ -415,9 +443,14 @@ class CompiledFunctionReturningArray(CompiledFunction):
actx: PytatoPyOpenCLArrayContext actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
output_tags: FrozenSet[Tag]
output_axes: Tuple[pt.Axis, ...]
output_name: str output_name: str
def __call__(self, arg_id_to_arg) -> ArrayContainer: def __call__(self, arg_id_to_arg) -> ArrayContainer:
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
from .utils import get_cl_axes_from_pt_axes
input_kwargs_for_loopy = _args_to_cl_buffers( input_kwargs_for_loopy = _args_to_cl_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg) self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
...@@ -430,4 +463,7 @@ class CompiledFunctionReturningArray(CompiledFunction): ...@@ -430,4 +463,7 @@ class CompiledFunctionReturningArray(CompiledFunction):
# running out of memory. This mitigates that risk a bit, for now. # running out of memory. This mitigates that risk a bit, for now.
evt.wait() evt.wait()
return self.actx.thaw(out_dict[self.output_name]) return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
axes=get_cl_axes_from_pt_axes(
self.output_axes),
tags=self.output_tags))
...@@ -58,6 +58,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): ...@@ -58,6 +58,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
shape=tuple(self.rec(s) if isinstance(s, Array) else s shape=tuple(self.rec(s) if isinstance(s, Array) else s
for s in expr.shape), for s in expr.shape),
dtype=expr.dtype, dtype=expr.dtype,
axes=expr.axes,
tags=expr.tags) tags=expr.tags)
def map_size_param(self, expr: SizeParam) -> Array: def map_size_param(self, expr: SizeParam) -> Array:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment