diff --git a/doc/source/array.rst b/doc/source/array.rst
index 3b12f04ad10fd650349e9c8d6412e2a1e7a3a83f..de1cc66ed3b1ecc09b4b999a95b04461215a9b55 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -80,6 +80,12 @@ The :class:`GPUArray` Array Class
 
     .. versionadded: 2011.1
 
+    .. attribute :: __cuda_array_interface__
+
+        Return a `CUDA Array Interface
+        <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_
+        dict describing this array's data.
+
     .. method :: __len__()
 
         Returns the size of the leading dimension of *self*.
diff --git a/doc/source/tutorial.rst b/doc/source/tutorial.rst
index ea6a804f0edf67d3353a216169f8e4db1d0cabdd..2c401b88e51257d1d649f1b0c69a8244c02f0c80 100644
--- a/doc/source/tutorial.rst
+++ b/doc/source/tutorial.rst
@@ -191,6 +191,36 @@ only the second::
 
     func(numpy.intp(do2_ptr), block = (32, 1, 1), grid=(1, 1))
     print("doubled second only", array1, array2, "\n")
 
+Interoperability With Other Libraries Using The CUDA Array Interface
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Kernel calls can be passed arrays from other CUDA libraries that support the
+`CUDA Array Interface
+<https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`__. For
+example, to double a `CuPy <https://cupy.dev/>`_ array::
+
+    import cupy as cp
+
+    cupy_a = cp.random.randn(4, 4).astype(cp.float32)
+    func = mod.get_function("double_array")
+    func(cupy_a, block=(4, 4, 1), grid=(1, 1))
+
+:class:`~pycuda.gpuarray.GPUArray` implements the CUDA Array Interface, so its
+instances can be passed into functions from other libraries that support it.
+For example, to double a PyCUDA GPU Array using a `Numba
+<https://numba.pydata.org/>`_ kernel::
+
+    from numba import cuda
+
+    a_gpu = gpuarray.to_gpu(numpy.random.randn(4, 4).astype(numpy.float32))
+
+    @cuda.jit
+    def double(x):
+        i, j = cuda.grid(2)
+        x[i, j] *= 2
+
+    double[(4, 4), (1, 1)](a_gpu)
+
 Where to go from here
 ---------------------
diff --git a/doc/source/util.rst b/doc/source/util.rst
index 0cd85d698c2aa1724e96e9fe3ca63b97106c015e..f83907d69a1a86b8845844645a98cfb79ef032c5 100644
--- a/doc/source/util.rst
+++ b/doc/source/util.rst
@@ -21,6 +21,13 @@ It uses :func:`pycuda.tools.make_default_context` to create a compute context.
     on :data:`device`. This context is created by calling
     :func:`pycuda.tools.make_default_context`.
 
+.. module:: pycuda.autoprimaryctx
+
+The module :mod:`pycuda.autoprimaryctx` is similar to :mod:`pycuda.autoinit`,
+except that it retains the device primary context instead of creating a new
+context in :func:`pycuda.tools.make_default_context`. Notably, it also
+has ``device`` and ``context`` attributes.
+
 Choice of Device
 ----------------
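To make the documented attribute concrete, here is a minimal sketch (not part
of the patch itself) of the dict that ``__cuda_array_interface__`` exports,
following the property added to pycuda/gpuarray.py further down; the pointer
value will differ from run to run::

    import numpy as np
    import pycuda.autoinit  # noqa
    import pycuda.gpuarray as gpuarray

    a_gpu = gpuarray.to_gpu(np.zeros((4, 4), dtype=np.float32))

    # Per the property below: "shape" and "strides" mirror numpy,
    # "data" is (device_ptr, False), i.e. exported read-write,
    # "stream" is None and "version" is 3.
    print(a_gpu.__cuda_array_interface__)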
diff --git a/examples/cai_cupy_arrays.py b/examples/cai_cupy_arrays.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c524ba945a5bd61c593c1081b46a2ba1ddf080f
--- /dev/null
+++ b/examples/cai_cupy_arrays.py
@@ -0,0 +1,33 @@
+# Copyright 2008-2021 Andreas Kloeckner
+# Copyright 2021 NVIDIA Corporation
+
+import pycuda.autoinit  # noqa
+from pycuda.compiler import SourceModule
+
+import cupy as cp
+
+
+# Create a CuPy array (and a copy for comparison later)
+cupy_a = cp.random.randn(4, 4).astype(cp.float32)
+original = cupy_a.copy()
+
+
+# Create a kernel
+mod = SourceModule("""
+    __global__ void doublify(float *a)
+    {
+      int idx = threadIdx.x + threadIdx.y*4;
+      a[idx] *= 2;
+    }
+    """)
+
+func = mod.get_function("doublify")
+
+# Invoke PyCUDA kernel on a CuPy array
+func(cupy_a, block=(4, 4, 1), grid=(1, 1), shared=0)
+
+# Demonstrate that our CuPy array was modified in place by the PyCUDA kernel
+print("original array:")
+print(original)
+print("doubled with kernel:")
+print(cupy_a)
diff --git a/examples/cai_numba.py b/examples/cai_numba.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a94ee48f495e2a152d219ab7aaf02b8d7c87d1c
--- /dev/null
+++ b/examples/cai_numba.py
@@ -0,0 +1,34 @@
+# Copyright 2008-2021 Andreas Kloeckner
+# Copyright 2021 NVIDIA Corporation
+
+from numba import cuda
+
+import pycuda.driver as pycuda
+# We use autoprimaryctx instead of autoinit because Numba can only operate on a
+# primary context
+import pycuda.autoprimaryctx  # noqa
+import pycuda.gpuarray as gpuarray
+
+import numpy
+
+
+# Create a PyCUDA gpuarray
+a_gpu = gpuarray.to_gpu(numpy.random.randn(4, 4).astype(numpy.float32))
+print("original array:")
+print(a_gpu)
+
+
+# A standard Numba kernel that doubles its input array
+@cuda.jit
+def double(x):
+    i, j = cuda.grid(2)
+
+    if i < x.shape[0] and j < x.shape[1]:
+        x[i, j] *= 2
+
+
+# Call the Numba kernel on the PyCUDA gpuarray, using the CUDA Array Interface
+# transparently
+double[(4, 4), (1, 1)](a_gpu)
+print("doubled with numba:")
+print(a_gpu)
diff --git a/pycuda/autoprimaryctx.py b/pycuda/autoprimaryctx.py
new file mode 100644
index 0000000000000000000000000000000000000000..537c8610261ba97c07605377297b103ffe2855ea
--- /dev/null
+++ b/pycuda/autoprimaryctx.py
@@ -0,0 +1,31 @@
+import pycuda.driver as cuda
+import atexit
+
+# Initialize CUDA
+cuda.init()
+
+from pycuda.tools import make_default_context  # noqa: E402
+
+
+def _retain_primary_context(dev):
+    context = dev.retain_primary_context()
+    context.push()
+    return context
+
+
+global context
+context = make_default_context(_retain_primary_context)
+device = context.get_device()
+
+
+def _finish_up():
+    global context
+    context.pop()
+    context = None
+
+    from pycuda.tools import clear_context_caches
+
+    clear_context_caches()
+
+
+atexit.register(_finish_up)
diff --git a/pycuda/driver.py b/pycuda/driver.py
index 6bfd097e560798ed6a658628fec0536b279b29e5..47d15b196bab8c3d9ca996d1ef43fed2fa9dd9c0 100644
--- a/pycuda/driver.py
+++ b/pycuda/driver.py
@@ -1,3 +1,8 @@
+__copyright__ = """
+Copyright 2008-2021 Andreas Kloeckner
+Copyright 2021 NVIDIA Corporation
+"""
+
 import os
 
 import numpy as np
@@ -208,6 +213,12 @@ def _add_functionality():
                 arg_data.append(_my_bytes(_memoryview(arg)))
                 format += "%ds" % arg.itemsize
             else:
+                cai = getattr(arg, "__cuda_array_interface__", None)
+                if cai:
+                    arg_data.append(cai["data"][0])
+                    format += "P"
+                    continue
+
                 try:
                     gpudata = np.uintp(arg.gpudata)
                 except AttributeError:
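With the pycuda/driver.py branch above, any kernel argument that exposes
``__cuda_array_interface__`` is packed as a raw device pointer (struct format
``"P"``). As a hedged sketch of how that path can be checked end to end,
assuming CuPy is installed (the kernel and array sizes are illustrative,
mirroring the tutorial example)::

    import cupy as cp
    import pycuda.autoinit  # noqa
    from pycuda.compiler import SourceModule

    mod = SourceModule("""
        __global__ void double_array(float *a)
        {
          int idx = threadIdx.x + threadIdx.y*4;
          a[idx] *= 2;
        }
        """)

    a = cp.random.randn(4, 4).astype(cp.float32)
    expected = a * 2  # computed before the kernel mutates a in place

    # The CuPy array is handed to the PyCUDA kernel directly
    mod.get_function("double_array")(a, block=(4, 4, 1), grid=(1, 1))
    assert cp.allclose(a, expected)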
diff --git a/pycuda/gpuarray.py b/pycuda/gpuarray.py
index f5908a064108bc142a6a5c21f47b7c33a16015e8..373cf00532b51c7e870d2e0c09e092e418e74378 100644
--- a/pycuda/gpuarray.py
+++ b/pycuda/gpuarray.py
@@ -1,3 +1,8 @@
+__copyright__ = """
+Copyright 2008-2021 Andreas Kloeckner
+Copyright 2021 NVIDIA Corporation
+"""
+
 import numpy as np
 import pycuda.elementwise as elementwise
 from pytools import memoize, memoize_method
@@ -252,6 +257,26 @@
 
         self._grid, self._block = splay(self.mem_size)
 
+    @property
+    def __cuda_array_interface__(self):
+        """Returns a CUDA Array Interface dictionary describing this array's
+        data."""
+        if self.gpudata is not None:
+            ptr = int(self.gpudata)
+        else:
+            ptr = 0
+
+        return {
+            "shape": self.shape,
+            "strides": self.strides,
+            # data is a tuple: (ptr, readonly) - always export GPUArray
+            # instances as read-write
+            "data": (ptr, False),
+            "typestr": self.dtype.str,
+            "stream": None,
+            "version": 3
+        }
+
     @property
     def ndim(self):
         return len(self.shape)
diff --git a/test/test_driver.py b/test/test_driver.py
index c720cebf1ac7c9d9de8928f6213e948e0411a354..b022aa379aa0c9d8bd21bf7a7bfe30ac07f24cf6 100644
--- a/test/test_driver.py
+++ b/test/test_driver.py
@@ -1,3 +1,8 @@
+__copyright__ = """
+Copyright 2008-2021 Andreas Kloeckner
+Copyright 2021 NVIDIA Corporation
+"""
+
 import numpy as np
 import numpy.linalg as la
 from pycuda.tools import mark_cuda_test, dtype_to_ctype
@@ -144,6 +149,23 @@ class TestDriver:
         diff = (a_g * b_g).get() - a * b
         assert la.norm(diff) == 0
 
+    @mark_cuda_test
+    def test_gpuarray_cai(self):
+        a = np.zeros(10, dtype=np.float32)
+        a_g = gpuarray.to_gpu(a)
+        cai = a_g.__cuda_array_interface__
+        ptr = cai["data"][0]
+        masked = cai["data"][1]
+
+        assert cai["shape"] == a.shape
+        assert cai["strides"] == a.strides
+        assert cai["typestr"] == a.dtype.str
+        assert isinstance(ptr, int)
+        assert ptr != 0
+        assert not masked
+        assert cai["stream"] is None
+        assert cai["version"] == 3
+
     @mark_cuda_test
     def donottest_cublas_mixing(self):
         self.test_streamed_kernel()
@@ -1054,6 +1076,55 @@ def test_pointer_holder_base():
     print(ary.get())
 
 
+# A class to emulate an object from outside PyCUDA that implements the CUDA
+# Array Interface
+class CudaArrayInterfaceImpl:
+    def __init__(self, size, itemsize, dtype):
+        self._shape = (size,)
+        self._strides = (itemsize,)
+        self._typestr = dtype.str
+        self._ptr = drv.mem_alloc(size * itemsize)
+
+    @property
+    def __cuda_array_interface__(self):
+        return {
+            "shape": self._shape,
+            "strides": self._strides,
+            "typestr": self._typestr,
+            "data": (int(self._ptr), False),
+            "stream": None,
+            "version": 3
+        }
+
+    @property
+    def ptr(self):
+        return self._ptr
+
+
+def test_pass_cai_array():
+    dtype = np.int32
+    size = 1024
+    np_array = np.arange(size, dtype=dtype)
+    cai_array = CudaArrayInterfaceImpl(size, np_array.itemsize, np_array.dtype)
+
+    mod = SourceModule(
+        """
+        __global__ void gpu_arange(int *x)
+        {
+            const int i = threadIdx.x;
+            x[i] = i;
+        }
+        """
+    )
+
+    gpu_arange = mod.get_function("gpu_arange")
+    gpu_arange(cai_array, grid=(1,), block=(size, 1, 1))
+
+    host_array = np.empty_like(np_array)
+    drv.memcpy_dtoh(host_array, cai_array.ptr)
+    assert (host_array == np_array).all()
+
+
 def test_import_pyopencl_before_pycuda():
     try:
         import pyopencl  # noqa
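Going the other way, a consumer needs nothing from PyCUDA beyond the exported
dict. The following sketch rebuilds a host copy of a GPUArray purely from the
interface, mirroring the ``memcpy_dtoh`` call in ``test_pass_cai_array``
above; it assumes ``memcpy_dtoh`` also accepts the raw integer device pointer
from ``data[0]``, not just a ``DeviceAllocation``::

    import numpy as np
    import pycuda.autoinit  # noqa
    import pycuda.driver as drv
    import pycuda.gpuarray as gpuarray

    a_gpu = gpuarray.to_gpu(np.arange(16, dtype=np.int32))
    cai = a_gpu.__cuda_array_interface__

    # "shape" and "typestr" describe the layout; "data"[0] is the
    # device pointer to copy from.
    host = np.empty(cai["shape"], dtype=np.dtype(cai["typestr"]))
    drv.memcpy_dtoh(host, cai["data"][0])
    assert (host == np.arange(16, dtype=np.int32)).all()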