From 7bf764e00734830f23c5b185999ce525f7377bc5 Mon Sep 17 00:00:00 2001 From: Mit Kotak Date: Fri, 5 Aug 2022 16:54:40 -0500 Subject: [PATCH] Laid out the basic structure for PytatoCUDAGraphContext --- arraycontext/__init__.py | 5 +- arraycontext/impl/pytato/__init__.py | 226 +++++++++++++++++++++++++++ arraycontext/impl/pytato/compile.py | 107 +++++++++++++ arraycontext/pytest.py | 22 +++ test/test_arraycontext.py | 14 +- 5 files changed, 366 insertions(+), 8 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 06e0b96..e1f5383 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -74,7 +74,8 @@ from .container.traversal import ( from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import (PytatoPyOpenCLArrayContext, - PytatoJAXArrayContext) + PytatoJAXArrayContext, + PytatoCUDAGraphArrayContext) from .impl.jax import EagerJAXArrayContext from .pytest import ( @@ -120,7 +121,7 @@ __all__ = ( "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", - "PytatoJAXArrayContext", + "PytatoJAXArrayContext", "PytatoCUDAGraphArrayContext", "EagerJAXArrayContext", "make_loopy_program", diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8d7e042..2ef81fc 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -11,6 +11,7 @@ Following :mod:`pytato`-based array context are provided: .. autoclass:: PytatoPyOpenCLArrayContext .. autoclass:: PytatoJAXArrayContext +.. autoclass:: PytatoCUDAGraphContext Compiling a Python callable (Internal) @@ -789,4 +790,229 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): # }}} +# {{{ PytatoJAXArrayContext + +class PytatoCUDAGraphArrayContext(_BasePytatoArrayContext): + """ + An arraycontext that uses :mod:`pytato` to represent the thawed state of + the arrays and compiles the expressions using + :class:`pytato.target.pycuda.CUDAGraphTarget`. + """ + + def __init__(self, + *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]] + = None) -> None: + """ + :arg compile_trace_callback: A function of three arguments + *(what, stage, ir)*, where *what* identifies the object + being compiled, *stage* is a string describing the compilation + pass, and *ir* is an object containing the intermediate + representation. This interface should be considered + unstable. + """ + import pytato as pt + from pycuda.gpuarray import GPUArray + super().__init__(compile_trace_callback=compile_trace_callback) + self.array_types = (pt.Array, GPUArray) + + @property + def _frozen_array_types(self) -> Tuple[Type, ...]: + from pycuda.gpuarray import GPUArray + return (GPUArray, ) + + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: + if allowed_types is None: + allowed_types = self.array_types + + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + # {{{ ArrayContext interface + + def zeros_like(self, ary): + def _zeros_like(array): + return self.zeros(array.shape, array.dtype) + + return self._rec_map_container(_zeros_like, ary, default_scalar=0) + + def from_numpy(self, array): + import pycuda.gpuarray as gpuarray + import pytato as pt + + def _from_numpy(ary): + return pt.make_data_wrapper(gpuarray.to_gpu(ary)) + + return with_array_context( + self._rec_map_container(_from_numpy, array, (np.ndarray,)), + actx=self) + + def to_numpy(self, array): + def _to_numpy(ary): + return ary.get() + + return with_array_context( + self._rec_map_container(_to_numpy, self.freeze(array)), + actx=None) + + def freeze(self, array): + if np.isscalar(array): + return array + + import pytato as pt + + from pycuda.gpuarray import GPUArray + from arraycontext.container.traversal import rec_keyed_map_array_container + from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + + array_as_dict: Dict[str, Union[GPUArray, pt.Array]] = {} + key_to_frozen_subary: Dict[str, GPUArray] = {} + key_to_pt_arrays: Dict[str, pt.Array] = {} + + def _record_leaf_ary_in_dict(key: Tuple[Any, ...], + ary: Union[GPUArray, pt.Array]) -> None: + key_str = "_ary" + _ary_container_key_stringifier(key) + array_as_dict[key_str] = ary + + rec_keyed_map_array_container(_record_leaf_ary_in_dict, array) + + # {{{ remove any non pytato arrays from array_as_dict + + for key, subary in array_as_dict.items(): + if isinstance(subary, GPUArray): + key_to_frozen_subary[key] = subary.block_until_ready() + elif isinstance(subary, pt.DataWrapper): + # trivial freeze. + key_to_frozen_subary[key] = subary.data.block_until_ready() + elif isinstance(subary, pt.Array): + key_to_pt_arrays[key] = subary + else: + raise TypeError( + f"{type(self).__name__}.freeze invoked with an unsupported " + f"array type: got '{type(subary).__name__}', but expected one " + f"of {self.array_types}") + + # }}} + + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) + transformed_dag = self.transform_dag(pt_dict_of_named_arrays) + pt_prg = pt.generate_cudagraph(transformed_dag) + out_dict = pt_prg() + assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 + + key_to_frozen_subary = { + **key_to_frozen_subary, + **{k: v.block_until_ready() + for k, v in out_dict.items()} + } + + def _to_frozen(key: Tuple[Any, ...], ary) -> GPUArray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + + def thaw(self, array): + import pytato as pt + + def _thaw(ary): + return pt.make_data_wrapper(ary) + + return with_array_context( + self._rec_map_container(_thaw, array, self._frozen_array_types), + actx=self) + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + from .compile import LazilyCUDAGraphCompilingFunctionCaller + return LazilyCUDAGraphCompilingFunctionCaller(self, f) + + def tag(self, tags: ToTagSetConvertible, array): + from pycuda.gpuarray import GPUArray + + def _tag(ary): + if isinstance(ary, GPUArray): + return ary + else: + return ary.tagged(_preprocess_array_tags(tags)) + + return self._rec_map_container(_tag, array) + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + from pycuda.gpuarray import GPUArray + + def _tag_axis(ary): + if isinstance(ary, GPUArray): + return ary + else: + return ary.with_tagged_axis(iaxis, tags) + + return self._rec_map_container(_tag_axis, array) + + # }}} + + # {{{ compilation + + def call_loopy(self, program, **kwargs): + raise NotImplementedError( + "Calling loopy on GPUArray arrays is not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array operations using" + " ArrayContext.np.") + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import pytato as pt + from pycuda.gpuarray import GPUArray + if arg_names is None: + arg_names = (None,) * len(args) + + def preprocess_arg(name, arg): + if isinstance(arg, GPUArray): + ary = self.thaw(arg) + elif isinstance(arg, pt.Array): + ary = arg + else: + raise TypeError( + f"{type(self).__name__}.einsum invoked with an unsupported " + f"array type: got '{type(arg).__name__}', but expected one " + f"of {self.array_types}") + + if name is not None: + # Tagging Placeholders with naming-related tags is pointless: + # They already have names. It's also counterproductive, as + # multiple placeholders with the same name that are not + # also the same object are not allowed, and this would produce + # a different Placeholder object of the same name. + if (not isinstance(ary, pt.Placeholder) + and not ary.tags_of_type(NameHint)): + ary = ary.tagged(NameHint(name)) + + return ary + + return pt.einsum(spec, *[ + preprocess_arg(name, arg) + for name, arg in zip(arg_names, args) + ]).tagged(_preprocess_array_tags(tagged)) + + def clone(self): + return type(self)() + +# }}} + # vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 3282328..1d9d42c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -2,6 +2,7 @@ .. autoclass:: BaseLazilyCompilingFunctionCaller .. autoclass:: LazilyPyOpenCLCompilingFunctionCaller .. autoclass:: LazilyJAXCompilingFunctionCaller +.. autoclass:: LazilyCUDAGraphCompilingFunctionCaller .. autoclass:: CompiledFunction .. autoclass:: FromArrayContextCompile """ @@ -33,6 +34,7 @@ from arraycontext.context import ArrayT from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext.impl.pytato import (_BasePytatoArrayContext, PytatoJAXArrayContext, + PytatoCUDAGraphArrayContext, PytatoPyOpenCLArrayContext) from arraycontext.container.traversal import rec_keyed_map_array_container @@ -506,6 +508,49 @@ class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): return pytato_program, name_in_program_to_tags, name_in_program_to_axes +class LazilyCUDAGraphCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): + @property + def compiled_function_returning_array_container_class( + self) -> Type["CompiledFunction"]: + return CompiledCUDAGraphFunctionReturningArrayContainer + + @property + def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + return CompiledCUDAGraphFunctionReturningArray + + def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): + if prg_id is None: + prg_id = self.f + + self.actx._compile_trace_callback( + prg_id, "pre_transform_dag", dict_of_named_arrays) + + with ProcessLogger(logger, "transform_dag for '{prg_id}'"): + pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) + + self.actx._compile_trace_callback( + prg_id, "post_transform_dag", pt_dict_of_named_arrays) + + name_in_program_to_tags = { + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_axes = { + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()} + + self.actx._compile_trace_callback( + prg_id, "pre_generate_cudagraph", pt_dict_of_named_arrays) + + with ProcessLogger(logger, f"generate_cudagraph for '{prg_id}'"): + pytato_program = pt.generate_cudagraph( + pt_dict_of_named_arrays, + function_name=_prg_id_to_kernel_name(prg_id)) + + self.actx._compile_trace_callback( + prg_id, "post_generate_cudagraph", pytato_program) + + return pytato_program, name_in_program_to_tags, name_in_program_to_axes + def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): input_kwargs_for_loopy = {} @@ -732,3 +777,65 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) return self.actx.thaw(out_dict[self.output_name]) + +@dataclass(frozen=True) +class CompiledCUDAGraphFunctionReturningArrayContainer(CompiledFunction): + """ + .. attribute:: output_id_to_name_in_program + + A mapping from output id to the name of + :class:`pytato.array.NamedArray` in + :attr:`CompiledFunction.pytato_program`. Output id is represented by + the key of a leaf array in the array container + :attr:`CompiledFunction.output_template`. + + .. attribute:: output_template + + An instance of :class:`arraycontext.ArrayContainer` that is the return + type of the callable. + """ + actx: PytatoCUDAGraphArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] + name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + output_template: ArrayContainer + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + out_dict = self.pytato_program(**input_kwargs_for_loopy) + + def to_output_template(keys, _): + return self.actx.thaw( + out_dict[self.output_id_to_name_in_program[keys]] + .block_until_ready() + ) + + return rec_keyed_map_array_container(to_output_template, + self.output_template) + + +@dataclass(frozen=True) +class CompiledJAXFunctionReturningArray(CompiledFunction): + """ + .. attribute:: output_name_in_program + + Name of the output array in the program. + """ + actx: PytatoCUDAGraphArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_tags: FrozenSet[Tag] + output_axes: Tuple[pt.Axis, ...] + output_name: str + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) + + return self.actx.thaw(out_dict[self.output_name]) \ No newline at end of file diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1eceb49..2aae1ba 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -194,6 +194,27 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): def __str__(self): return "" +class _PytestPytatoCUDAGraphArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + pass + + @classmethod + def is_available(cls) -> bool: + try: + import pycuda # noqa: F401 + import pytato # noqa: F401 + return True + except ImportError: + return False + + def __call__(self): + from arraycontext import PytatoCUDAGraphArrayContext + import pycuda.autoinit + return PytatoCUDAGraphArrayContext() + + def __str__(self): + return "" + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { @@ -202,6 +223,7 @@ _ARRAY_CONTEXT_FACTORY_REGISTRY: \ _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, + "pytato:cudagraph": _PytestPytatoCUDAGraphArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108..ae6b038 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -44,7 +44,8 @@ from arraycontext import ( # noqa: F401 from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory) + _PytestPytatoJaxArrayContextFactory, + _PytestPytatoCUDAGraphArrayContextFactory) import logging @@ -88,11 +89,12 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( pytest_generate_tests = pytest_generate_tests_for_array_contexts([ - _PyOpenCLArrayContextForTestsFactory, - _PyOpenCLArrayContextWithHostScalarsForTestsFactory, - _PytatoPyOpenCLArrayContextForTestsFactory, - _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory, + # _PyOpenCLArrayContextForTestsFactory, + # _PyOpenCLArrayContextWithHostScalarsForTestsFactory, + # _PytatoPyOpenCLArrayContextForTestsFactory, + # _PytestEagerJaxArrayContextFactory, + # _PytestPytatoJaxArrayContextFactory, + _PytestPytatoCUDAGraphArrayContextFactory ]) -- GitLab