diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3f7e7601d9f9995b2a2fd7044b99a0818771864a..9dd805a6a955c079396ea3f54e1c5af956b59cb4 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -21,6 +21,7 @@ Python 3 Nvidia Titan V: export PYOPENCL_TEST=nvi:titan build_py_project_in_venv pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html + pip install git+https://gitlab.tiker.net/kaushikcfd/pycuda.git@cudagraph#egg=pycuda test_py_project tags: diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 06e0b96c5f661c40aaf92d08f6aa5851daa6ddf4..e1f5383de6429f4e22b6e8102d967df61c98d09d 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -74,7 +74,8 @@ from .container.traversal import ( from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import (PytatoPyOpenCLArrayContext, - PytatoJAXArrayContext) + PytatoJAXArrayContext, + PytatoCUDAGraphArrayContext) from .impl.jax import EagerJAXArrayContext from .pytest import ( @@ -120,7 +121,7 @@ __all__ = ( "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", - "PytatoJAXArrayContext", + "PytatoJAXArrayContext", "PytatoCUDAGraphArrayContext", "EagerJAXArrayContext", "make_loopy_program", diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8d7e0426b87263483a68aa771a9e10c70cde094a..94d6045cf5ded936ca919b847cc15fc62896bc2d 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -11,6 +11,7 @@ Following :mod:`pytato`-based array context are provided: .. autoclass:: PytatoPyOpenCLArrayContext .. autoclass:: PytatoJAXArrayContext +.. 
autoclass:: PytatoCUDAGraphArrayContext Compiling a Python callable (Internal) @@ -423,7 +424,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): self._dag_transform_cache[normalized_expr]) assert len(pt_prg.bound_arguments) == 0 - evt, out_dict = pt_prg(self.queue, **bound_arguments) + evt, out_dict = pt_prg(self.queue, + allocator=self.allocator, + **bound_arguments) evt.wait() assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 @@ -789,4 +792,242 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): # }}} +# {{{ PytatoCUDAGraphArrayContext + + +class PytatoCUDAGraphArrayContext(_BasePytatoArrayContext): + """ + An arraycontext that uses :mod:`pytato` to represent the thawed state of + the arrays and compiles the expressions using + :class:`pytato.target.pycuda.CUDAGraphTarget`. + """ + + def __init__(self, + *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]] + = None, + allocator=None) -> None: + """ + :arg compile_trace_callback: A function of three arguments + *(what, stage, ir)*, where *what* identifies the object + being compiled, *stage* is a string describing the compilation + pass, and *ir* is an object containing the intermediate + representation. This interface should be considered + unstable. + """ + import pytato as pt + from pycuda.gpuarray import GPUArray + import pycuda + super().__init__(compile_trace_callback=compile_trace_callback) + self.array_types = (pt.Array, GPUArray) + if allocator is None: + self.allocator = pycuda.driver.mem_alloc + from warnings import warn + warn("PytatoCUDAGraphArrayContext created without an allocator on a GPU. " + "This can lead to high numbers of memory allocations. " + "Please consider using a pycuda.autoinit. 
" + "Run with allocator=False to disable this warning.") + else: + self.allocator = allocator + + @property + def _frozen_array_types(self) -> Tuple[Type, ...]: + from pycuda.gpuarray import GPUArray + return (GPUArray, ) + + def _rec_map_container( + self, func: Callable[[Array], Array], array: ArrayOrContainer, + allowed_types: Optional[Tuple[type, ...]] = None, *, + default_scalar: Optional[ScalarLike] = None, + strict: bool = False) -> ArrayOrContainer: + if allowed_types is None: + allowed_types = self.array_types + + def _wrapper(ary): + if isinstance(ary, allowed_types): + return func(ary) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + # {{{ ArrayContext interface + + def zeros_like(self, ary): + def _zeros_like(array): + return self.zeros(array.shape, array.dtype) + + return self._rec_map_container(_zeros_like, ary, default_scalar=0) + + def from_numpy(self, array): + import pycuda.gpuarray as gpuarray + import pytato as pt + + def _from_numpy(ary): + return pt.make_data_wrapper(gpuarray.to_gpu(ary)) + + return with_array_context( + self._rec_map_container(_from_numpy, array, (np.ndarray,), strict=True), + actx=self) + + def to_numpy(self, array): + def _to_numpy(ary): + return ary.get() + + return with_array_context( + self._rec_map_container(_to_numpy, self.freeze(array)), + actx=None) + + def freeze(self, array): + if np.isscalar(array): + return array + + import pytato as pt + + from pycuda.gpuarray import GPUArray + from arraycontext.container.traversal import rec_keyed_map_array_container + from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + + array_as_dict: Dict[str, Union[GPUArray, pt.Array]] = {} + 
key_to_frozen_subary: Dict[str, GPUArray] = {} + key_to_pt_arrays: Dict[str, pt.Array] = {} + + def _record_leaf_ary_in_dict(key: Tuple[Any, ...], + ary: Union[GPUArray, pt.Array]) -> None: + key_str = "_ary" + _ary_container_key_stringifier(key) + array_as_dict[key_str] = ary + + rec_keyed_map_array_container(_record_leaf_ary_in_dict, array) + + # {{{ remove any non pytato arrays from array_as_dict + + for key, subary in array_as_dict.items(): + if isinstance(subary, GPUArray): + key_to_frozen_subary[key] = subary + elif isinstance(subary, pt.DataWrapper): + # trivial freeze. + import pycuda.gpuarray as gpuarray + key_to_frozen_subary[key] = gpuarray.to_gpu(subary.data) + elif isinstance(subary, pt.Array): + key_to_pt_arrays[key] = subary + else: + raise TypeError( + f"{type(self).__name__}.freeze invoked with an unsupported " + f"array type: got '{type(subary).__name__}', but expected one " + f"of {self.array_types}") + + # }}} + + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) + transformed_dag = self.transform_dag(pt_dict_of_named_arrays) + pt_prg = pt.generate_cudagraph(transformed_dag) + out_dict = pt_prg() + assert len(set(out_dict) & set(key_to_frozen_subary)) == 0 + + key_to_frozen_subary = { + **key_to_frozen_subary, + **{k: v + for k, v in out_dict.items()} + } + + def _to_frozen(key: Tuple[Any, ...], ary) -> GPUArray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + + def thaw(self, array): + import pytato as pt + + def _thaw(ary): + return pt.make_data_wrapper(ary) + + return with_array_context( + self._rec_map_container(_thaw, array, self._frozen_array_types), + actx=self) + + def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + from .compile import LazilyCUDAGraphCompilingFunctionCaller + return LazilyCUDAGraphCompilingFunctionCaller(self, f) + + def tag(self, tags: 
ToTagSetConvertible, array): + from pycuda.gpuarray import GPUArray + + def _tag(ary): + if isinstance(ary, GPUArray): + return ary + else: + return ary.tagged(_preprocess_array_tags(tags)) + + return self._rec_map_container(_tag, array) + + def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): + from pycuda.gpuarray import GPUArray + + def _tag_axis(ary): + if isinstance(ary, GPUArray): + return ary + else: + return ary.with_tagged_axis(iaxis, tags) + + return self._rec_map_container(_tag_axis, array) + + # }}} + + # {{{ compilation + + def call_loopy(self, program, **kwargs): + raise NotImplementedError( + "Calling loopy on GPUArray arrays is not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array operations using" + " ArrayContext.np.") + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import pytato as pt + from pycuda.gpuarray import GPUArray + if arg_names is None: + arg_names = (None,) * len(args) + + def preprocess_arg(name, arg): + if isinstance(arg, GPUArray): + ary = self.thaw(arg) + elif isinstance(arg, pt.Array): + ary = arg + else: + raise TypeError( + f"{type(self).__name__}.einsum invoked with an unsupported " + f"array type: got '{type(arg).__name__}', but expected one " + f"of {self.array_types}") + + if name is not None: + # Tagging Placeholders with naming-related tags is pointless: + # They already have names. It's also counterproductive, as + # multiple placeholders with the same name that are not + # also the same object are not allowed, and this would produce + # a different Placeholder object of the same name. 
+ if (not isinstance(ary, pt.Placeholder) + and not ary.tags_of_type(NameHint)): + ary = ary.tagged(NameHint(name)) + + return ary + + return pt.einsum(spec, *[ + preprocess_arg(name, arg) + for name, arg in zip(arg_names, args) + ]).tagged(_preprocess_array_tags(tagged)) + + def clone(self): + return type(self)() + +# }}} + # vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 32823286677d836872879b607589f740fe355e88..bb1bed487e5014cbe06da84dac64acd68ec3176b 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -2,6 +2,7 @@ .. autoclass:: BaseLazilyCompilingFunctionCaller .. autoclass:: LazilyPyOpenCLCompilingFunctionCaller .. autoclass:: LazilyJAXCompilingFunctionCaller +.. autoclass:: LazilyCUDAGraphCompilingFunctionCaller .. autoclass:: CompiledFunction .. autoclass:: FromArrayContextCompile """ @@ -33,6 +34,7 @@ from arraycontext.context import ArrayT from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext.impl.pytato import (_BasePytatoArrayContext, PytatoJAXArrayContext, + PytatoCUDAGraphArrayContext, PytatoPyOpenCLArrayContext) from arraycontext.container.traversal import rec_keyed_map_array_container @@ -105,6 +107,8 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # }}} +# {{{ utilities + def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. 
Stringifies an @@ -236,6 +240,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): else: raise NotImplementedError(type(arg)) +# }}} + + +# {{{ BaseLazilyCompilingFunctionCaller @dataclass class BaseLazilyCompilingFunctionCaller: @@ -366,6 +374,10 @@ class BaseLazilyCompilingFunctionCaller: self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) +# }}} + + +# {{{ LazilyPyOpenCLCompilingFunctionCaller class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property @@ -440,6 +452,8 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): return pytato_program, name_in_program_to_tags, name_in_program_to_axes +# }}} + # {{{ preserve back compat @@ -461,6 +475,8 @@ class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): # }}} +# {{{ LazilyJAXCompilingFunctionCaller + class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( self) -> Type["CompiledFunction"]: @@ -506,6 +522,50 @@ class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): return pytato_program, name_in_program_to_tags, name_in_program_to_axes +class LazilyCUDAGraphCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): + @property + def compiled_function_returning_array_container_class( + self) -> Type["CompiledFunction"]: + return CompiledCUDAGraphFunctionReturningArrayContainer + + @property + def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]: + return CompiledCUDAGraphFunctionReturningArray + + def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): + if prg_id is None: + prg_id = self.f + + self.actx._compile_trace_callback( + prg_id, "pre_transform_dag", dict_of_named_arrays) + + with ProcessLogger(logger, f"transform_dag for '{prg_id}'"): + pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays) + + self.actx._compile_trace_callback( + prg_id, 
"post_transform_dag", pt_dict_of_named_arrays) + + name_in_program_to_tags = { + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_axes = { + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()} + + self.actx._compile_trace_callback( + prg_id, "pre_generate_cudagraph", pt_dict_of_named_arrays) + + with ProcessLogger(logger, f"generate_cudagraph for '{prg_id}'"): + pytato_program = pt.generate_cudagraph( + pt_dict_of_named_arrays, + function_name=_prg_id_to_kernel_name(prg_id)) + + self.actx._compile_trace_callback( + prg_id, "post_generate_cudagraph", pytato_program) + + return pytato_program, name_in_program_to_tags, name_in_program_to_axes + + def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): input_kwargs_for_loopy = {} @@ -513,10 +573,14 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): if np.isscalar(arg): if isinstance(actx, PytatoPyOpenCLArrayContext): import pyopencl.array as cla - arg = cla.to_device(actx.queue, np.array(arg)) + arg = cla.to_device(actx.queue, np.array(arg), + allocator=actx.allocator) elif isinstance(actx, PytatoJAXArrayContext): import jax arg = jax.device_put(arg) + elif isinstance(actx, PytatoCUDAGraphArrayContext): + import pycuda.gpuarray as gpuarray + arg = gpuarray.to_gpu(np.array(arg)) else: raise NotImplementedError(type(actx)) @@ -553,6 +617,10 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): return _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg) +# }}} + + +# {{{ compiled function class CompiledFunction(abc.ABC): """ @@ -582,6 +650,10 @@ class CompiledFunction(abc.ABC): """ pass +# }}} + + +# {{{ copmiled pyopencl function @dataclass(frozen=True) class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): @@ -670,7 +742,10 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): self.output_axes), tags=self.output_tags)) +# }}} + 
+# {{{ comiled jax function @dataclass(frozen=True) class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): """ @@ -732,3 +807,69 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) return self.actx.thaw(out_dict[self.output_name]) + + +@dataclass(frozen=True) +class CompiledCUDAGraphFunctionReturningArrayContainer(CompiledFunction): + """ + .. attribute:: output_id_to_name_in_program + + A mapping from output id to the name of + :class:`pytato.array.NamedArray` in + :attr:`CompiledFunction.pytato_program`. Output id is represented by + the key of a leaf array in the array container + :attr:`CompiledFunction.output_template`. + + .. attribute:: output_template + + An instance of :class:`arraycontext.ArrayContainer` that is the return + type of the callable. + """ + actx: PytatoCUDAGraphArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] + name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]] + output_template: ArrayContainer + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + out_dict = self.pytato_program(**input_kwargs_for_loopy) + + def to_output_template(keys, _): + return self.actx.thaw( + out_dict[self.output_id_to_name_in_program[keys]] + ) + + return rec_keyed_map_array_container(to_output_template, + self.output_template) + + +@dataclass(frozen=True) +class CompiledCUDAGraphFunctionReturningArray(CompiledFunction): + """ + .. attribute:: output_name_in_program + + Name of the output array in the program. 
+ """ + actx: PytatoCUDAGraphArrayContext + pytato_program: pt.target.BoundProgram + input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + output_tags: FrozenSet[Tag] + output_axes: Tuple[pt.Axis, ...] + output_name: str + + def __call__(self, arg_id_to_arg) -> ArrayContainer: + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + + evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) + + import pycuda.gpuarray as gpuarray + return self.actx.thaw(gpuarray.to_gpu(out_dict[self.output_name])) +# }}} + +# vim: foldmethod=marker diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1eceb4973920ff67ed772989695d7862b8c4021c..964dd4b7623626a0ae4f74fa47c072c4390890e7 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -195,6 +195,28 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): return "" +class _PytestPytatoCUDAGraphArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + pass + + @classmethod + def is_available(cls) -> bool: + try: + import pycuda # noqa: F401 + import pytato # noqa: F401 + return True + except ImportError: + return False + + def __call__(self): + from arraycontext import PytatoCUDAGraphArrayContext + import pycuda.autoinit # noqa + return PytatoCUDAGraphArrayContext() + + def __str__(self): + return "" + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -202,6 +224,7 @@ _ARRAY_CONTEXT_FACTORY_REGISTRY: \ _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, + "pytato:cudagraph": _PytestPytatoCUDAGraphArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, } diff --git a/requirements.txt b/requirements.txt index 
a4cb402530521fa040fbd45c4f50607976eed51c..43306fd6568d6225e54259a529d8fe7c7abc42a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,6 @@ git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://gitlab.tiker.net/inducer/loopy.git@pycuda_tgt#egg=loopy +git+https://gitlab.tiker.net/kaushikcfd/pycuda.git@cudagraph#egg=pycuda +git+https://gitlab.tiker.net/kaushikcfd/pytato.git@cudagraph#egg=pytato diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108e5d6cb63b083f2659a9aec51f8170ca43..03b472c31a5178e758f8542d3541bfd773e5b2dd 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -44,7 +44,8 @@ from arraycontext import ( # noqa: F401 from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory) + _PytestPytatoJaxArrayContextFactory, + _PytestPytatoCUDAGraphArrayContextFactory) import logging @@ -93,6 +94,7 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts([ _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestPytatoCUDAGraphArrayContextFactory ])