From 79bd38b1c3678f11881939e5bdafd1e49b6c7ecf Mon Sep 17 00:00:00 2001 From: Mit Kotak Date: Thu, 4 Aug 2022 22:30:37 -0500 Subject: [PATCH 1/2] Implemented simple broadcasting --- .github/workflows/ci.yml | 2 +- doc/array.rst | 20 +++ pycuda/elementwise.py | 62 ++++++- pycuda/gpuarray.py | 263 +++++++++++++++++++++++++----- pycuda/reduction.py | 26 +++ pycuda/tools.py | 14 ++ setup.py | 2 +- test/test_driver.py | 2 + test/test_gpuarray.py | 342 ++++++++++++++++++++++++--------------- 9 files changed, 558 insertions(+), 175 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b90957af..9b3a95b0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: uses: actions/setup-python@v1 with: # matches compat target in setup.py - python-version: '3.6' + python-version: '3.8' - name: "Main Script" run: | curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-flake8.sh diff --git a/doc/array.rst b/doc/array.rst index dacb6cbf..adbd9cfb 100644 --- a/doc/array.rst +++ b/doc/array.rst @@ -182,6 +182,10 @@ The :class:`GPUArray` Array Class .. method :: astype(dtype, stream=None) Return *self*, cast to *dtype*. + + .. method :: any(stream=None, allocator=None) + + .. method :: all(stream=None, allocator=None) .. attribute :: real @@ -352,6 +356,18 @@ Constructing :class:`GPUArray` Instances Join a sequence of arrays along a new axis. +.. function:: logical_and(x1, x2, /, out=None, * allocator=None) + + Returns the elementwise logical AND values of *x1* and *x2*. + +.. function:: logical_or(x1, x2, /, out=None, * allocator=None) + + Returns the elementwise logical OR values of *x1* and *x2*. + +.. function:: logical_not(x, /, out=None, * allocator=None) + + Returns the elementwise logical NOT of *x*. + Conditionals ^^^^^^^^^^^^ @@ -374,6 +390,10 @@ Reductions .. function:: sum(a, dtype=None, stream=None) +.. function:: any(a, stream=None, allocator=None) + +.. function:: all(a, stream=None, allocator=None) + .. function:: subset_sum(subset, a, dtype=None, stream=None) .. versionadded:: 2013.1 diff --git a/pycuda/elementwise.py b/pycuda/elementwise.py index 4e8601f0..7d633011 100644 --- a/pycuda/elementwise.py +++ b/pycuda/elementwise.py @@ -43,7 +43,6 @@ def get_elwise_module( after_loop="", ): from pycuda.compiler import SourceModule - return SourceModule( """ #include @@ -464,7 +463,20 @@ def get_linear_combination_kernel(summand_descriptors, dtype_z): @context_dependent_memoize -def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z): +def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z, + x_is_scalar=False, y_is_scalar=False): + """ + Returns a kernel corresponding to ``z = ax + by``. + + :arg x_is_scalar: A :class:`bool` which is *True* only if `x` is a scalar :class:`gpuarray`. + :arg y_is_scalar: A :class:`bool` which is *True* only if `y` is a scalar :class:`gpuarray`. 
+ """ + out_t = dtype_to_ctype(dtype_z) + x = "x[0]" if x_is_scalar else "x[i]" + ax = f"a*(({out_t}) {x})" + y = "y[0]" if y_is_scalar else "y[i]" + by = f"b*(({out_t}) {y})" + result = f"{ax} + {by}" return get_elwise_kernel( "%(tp_x)s a, %(tp_x)s *x, %(tp_y)s b, %(tp_y)s *y, %(tp_z)s *z" % { @@ -472,7 +484,7 @@ def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z): "tp_y": dtype_to_ctype(dtype_y), "tp_z": dtype_to_ctype(dtype_z), }, - "z[i] = a*x[i] + b*y[i]", + f"z[i] = {result}", "axpbyz", ) @@ -488,7 +500,17 @@ def get_axpbz_kernel(dtype_x, dtype_z): @context_dependent_memoize -def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator): +def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator, + x_is_scalar=False, y_is_scalar=False): + """ + Returns a kernel corresponding to ``z = x (operator) y``. + + :arg x_is_scalar: A :class:`bool` which is *True* only if `x` is a scalar :class:`gpuarray`. + :arg y_is_scalar: A :class:`bool` which is *True* only if `y` is a scalar :class:`gpuarray`. + """ + x = "x[0]" if x_is_scalar else "x[i]" + y = "y[0]" if y_is_scalar else "y[i]" + result = f"{x} {operator} {y}" return get_elwise_kernel( "%(tp_x)s *x, %(tp_y)s *y, %(tp_z)s *z" % { @@ -496,7 +518,7 @@ def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator): "tp_y": dtype_to_ctype(dtype_y), "tp_z": dtype_to_ctype(dtype_z), }, - "z[i] = x[i] %s y[i]" % operator, + f"z[i] = {result}", "multiply", ) @@ -741,14 +763,40 @@ def get_if_positive_kernel(crit_dtype, dtype): @context_dependent_memoize -def get_scalar_op_kernel(dtype_x, dtype_y, operator): +def get_where_kernel(crit_dtype, dtype): + return get_elwise_kernel( + [ + VectorArg(crit_dtype, "crit"), + VectorArg(dtype, "then_"), + VectorArg(dtype, "else_"), + VectorArg(dtype, "result"), + ], + "result[i] = crit[i] != 0 ? 
then_[i] : else_[i]", + "if_positive", + ) + + +@context_dependent_memoize +def get_scalar_op_kernel(dtype_x, dtype_a, dtype_y, operator): return get_elwise_kernel( "%(tp_x)s *x, %(tp_a)s a, %(tp_y)s *y" % { "tp_x": dtype_to_ctype(dtype_x), "tp_y": dtype_to_ctype(dtype_y), - "tp_a": dtype_to_ctype(dtype_x), + "tp_a": dtype_to_ctype(dtype_a), }, "y[i] = x[i] %s a" % operator, "scalarop_kernel", ) + + +@context_dependent_memoize +def get_logical_not_kernel(dtype_x, dtype_out): + return get_elwise_kernel( + [ + VectorArg(dtype_x, "x"), + VectorArg(dtype_out, "out"), + ], + "out[i] = (x[i] == 0)", + "logical_not", + ) diff --git a/pycuda/gpuarray.py b/pycuda/gpuarray.py index c9cacc4b..cf363bb2 100644 --- a/pycuda/gpuarray.py +++ b/pycuda/gpuarray.py @@ -25,6 +25,18 @@ def _get_common_dtype(obj1, obj2): return _get_common_dtype_base(obj1, obj2, has_double_support()) +def _get_broadcasted_binary_op_result(obj1, obj2, + dtype_getter=_get_common_dtype): + + if obj1.shape == obj2.shape: + return obj1._new_like_me(dtype_getter(obj1, obj2)) + elif obj1.shape == (): + return obj2._new_like_me(dtype_getter(obj1, obj2)) + elif obj2.shape == (): + return obj1._new_like_me(dtype_getter(obj1, obj2)) + else: + raise NotImplementedError("Broadcasting binary operator with shapes:" + f" {obj1.shape}, {obj2.shape}.") # {{{ vector types @@ -141,20 +153,22 @@ def _make_binary_op(operator): raise RuntimeError( "only contiguous arrays may " "be used as arguments to this operation" ) - - if isinstance(other, GPUArray): - assert self.shape == other.shape - + if isinstance(other, GPUArray) and (self, GPUArray): if not other.flags.forc: raise RuntimeError( "only contiguous arrays may " "be used as arguments to this operation" ) - result = self._new_like_me() + result = _get_broadcasted_binary_op_result(self, other) func = elementwise.get_binary_op_kernel( - self.dtype, other.dtype, result.dtype, operator - ) + self.dtype, + other.dtype, + result.dtype, + operator, + x_is_scalar=(self.shape == ()), + y_is_scalar=(other.shape == ())) + func.prepared_async_call( self._grid, self._block, @@ -166,9 +180,12 @@ def _make_binary_op(operator): ) return result - else: # scalar operator + elif isinstance(self, GPUArray): # scalar operator + assert np.isscalar(other) result = self._new_like_me() - func = elementwise.get_scalar_op_kernel(self.dtype, result.dtype, operator) + func = elementwise.get_scalar_op_kernel(self.dtype, + np.dtype(type(other)), + result.dtype, operator) func.prepared_async_call( self._grid, self._block, @@ -179,6 +196,8 @@ def _make_binary_op(operator): self.mem_size, ) return result + else: + return AssertionError return func @@ -391,38 +410,41 @@ class GPUArray: def _axpbyz(self, selffac, other, otherfac, out, add_timer=None, stream=None): """Compute ``out = selffac * self + otherfac*other``, where `other` is a vector..""" - assert self.shape == other.shape if not self.flags.forc or not other.flags.forc: raise RuntimeError( "only contiguous arrays may " "be used as arguments to this operation" ) - - func = elementwise.get_axpbyz_kernel(self.dtype, other.dtype, out.dtype) - + assert ((self.shape == other.shape == out.shape) + or ((self.shape == ()) and other.shape == out.shape) + or ((other.shape == ()) and self.shape == out.shape)) + func = elementwise.get_axpbyz_kernel( + self.dtype, other.dtype, out.dtype, + x_is_scalar=(self.shape == ()), + y_is_scalar=(other.shape == ())) if add_timer is not None: add_timer( 3 * self.size, func.prepared_timed_call( - self._grid, + out._grid, selffac, - self.gpudata, + 
out.gpudata, otherfac, other.gpudata, out.gpudata, - self.mem_size, + out.mem_size, ), ) else: func.prepared_async_call( - self._grid, - self._block, + out._grid, + out._block, stream, selffac, self.gpudata, otherfac, other.gpudata, out.gpudata, - self.mem_size, + out.mem_size, ) return out @@ -454,16 +476,26 @@ class GPUArray: raise RuntimeError( "only contiguous arrays may " "be used as arguments to this operation" ) + assert ((self.shape == other.shape == out.shape) + or ((self.shape == ()) and other.shape == out.shape) + or ((other.shape == ()) and self.shape == out.shape)) + + func = elementwise.get_binary_op_kernel( + self.dtype, + other.dtype, + out.dtype, + "*", + x_is_scalar=(self.shape == ()), + y_is_scalar=(other.shape == ())) - func = elementwise.get_binary_op_kernel(self.dtype, other.dtype, out.dtype, "*") func.prepared_async_call( - self._grid, - self._block, + out._grid, + out._block, stream, self.gpudata, other.gpudata, out.gpudata, - self.mem_size, + out.mem_size, ) return out @@ -500,17 +532,25 @@ class GPUArray: "only contiguous arrays may " "be used as arguments to this operation" ) - assert self.shape == other.shape + assert ((self.shape == other.shape == out.shape) + or ((self.shape == ()) and other.shape == out.shape) + or ((other.shape == ()) and self.shape == out.shape)) - func = elementwise.get_binary_op_kernel(self.dtype, other.dtype, out.dtype, "/") + func = elementwise.get_binary_op_kernel( + self.dtype, + other.dtype, + out.dtype, + "/", + x_is_scalar=(self.shape == ()), + y_is_scalar=(other.shape == ())) func.prepared_async_call( - self._grid, - self._block, + out._grid, + out._block, stream, self.gpudata, other.gpudata, out.gpudata, - self.mem_size, + out.mem_size, ) return out @@ -537,31 +577,35 @@ class GPUArray: if isinstance(other, GPUArray): # add another vector - result = self._new_like_me(_get_common_dtype(self, other)) + result = _get_broadcasted_binary_op_result(self, other) return self._axpbyz(1, other, 1, result) - else: + + elif np.isscalar(other): # add a scalar if other == 0: return self.copy() else: result = self._new_like_me(_get_common_dtype(self, other)) return self._axpbz(1, other, result) - + else: + return NotImplemented __radd__ = __add__ def __sub__(self, other): """Substract an array from an array or a scalar from an array.""" if isinstance(other, GPUArray): - result = self._new_like_me(_get_common_dtype(self, other)) + result = _get_broadcasted_binary_op_result(self, other) return self._axpbyz(1, other, -1, result) - else: + elif np.isscalar(other): if other == 0: return self.copy() else: # create a new array for the result result = self._new_like_me(_get_common_dtype(self, other)) return self._axpbz(1, -other, result) + else: + return NotImplemented def __rsub__(self, other): """Substracts an array by a scalar or an array:: @@ -593,11 +637,13 @@ class GPUArray: def __mul__(self, other): if isinstance(other, GPUArray): - result = self._new_like_me(_get_common_dtype(self, other)) + result = _get_broadcasted_binary_op_result(self, other) return self._elwise_multiply(other, result) - else: + elif np.isscalar(other): result = self._new_like_me(_get_common_dtype(self, other)) return self._axpbz(other, 0, result) + else: + return NotImplemented def __rmul__(self, scalar): result = self._new_like_me(_get_common_dtype(self, scalar)) @@ -615,16 +661,17 @@ class GPUArray: x = self / n """ if isinstance(other, GPUArray): - result = self._new_like_me(_get_common_dtype(self, other)) + result = _get_broadcasted_binary_op_result(self, other) return 
self._div(other, result) - else: + elif np.isscalar(other): if other == 1: return self.copy() else: # create a new array for the result result = self._new_like_me(_get_common_dtype(self, other)) return self._axpbz(1 / other, 0, result) - + else: + return NotImplemented __truediv__ = __div__ def __rdiv__(self, other): @@ -882,6 +929,12 @@ class GPUArray: return result + def any(self, stream=None, allocator=None): + return any(self, stream=stream, allocator=allocator) + + def all(self, stream=None, allocator=None): + return all(self, stream=stream, allocator=allocator) + def reshape(self, *shape, **kwargs): """Gives a new shape to an array without changing its data.""" @@ -1852,7 +1905,8 @@ def stack(arrays, axis=0, allocator=None): input_ndim = arrays[0].ndim axis = input_ndim if axis == -1 else axis - if not all(ary.shape == input_shape for ary in arrays[1:]): + import builtins + if not builtins.all(ary.shape == input_shape for ary in arrays[1:]): raise ValueError("arrays must have the same shape") if not (0 <= axis <= input_ndim): @@ -1927,6 +1981,32 @@ def if_positive(criterion, then_, else_, out=None, stream=None): return out +def where(criterion, then_, else_, out=None, stream=None): + if (criterion.shape != then_.shape != else_.shape): + raise NotImplementedError("shape broadcast not implemented") + + if (then_.dtype != else_.dtype): + raise NotImplementedError("dtype broadcast not implemented") + + func = elementwise.get_where_kernel(criterion.dtype, then_.dtype) + + if out is None: + out = empty_like(then_) + + func.prepared_async_call( + criterion._grid, + criterion._block, + stream, + criterion.gpudata, + then_.gpudata, + else_.gpudata, + out.gpudata, + criterion.size, + ) + + return out + + def _make_binary_minmax_func(which): def f(a, b, out=None, stream=None): if isinstance(a, GPUArray) and isinstance(b, GPUArray): @@ -1979,6 +2059,20 @@ def sum(a, dtype=None, stream=None, allocator=None): return krnl(a, stream=stream, allocator=allocator) +def any(a, stream=None, allocator=None): + from pycuda.reduction import get_any_kernel + + krnl = get_any_kernel(np.dtype(bool), a.dtype) + return krnl(a, stream=stream, allocator=allocator) + + +def all(a, stream=None, allocator=None): + from pycuda.reduction import get_all_kernel + + krnl = get_all_kernel(np.dtype(bool), a.dtype) + return krnl(a, stream=stream, allocator=allocator) + + def subset_sum(subset, a, dtype=None, stream=None, allocator=None): from pycuda.reduction import get_subset_sum_kernel @@ -2033,4 +2127,95 @@ subset_max = _make_subset_minmax_kernel("max") # }}} + +# {{{ logical ops + +def _logical_op(x1, x2, out, allocator, operator): + assert operator in ["&&", "||"] + allocator = ( + allocator + or getattr(x1, "allocator", None) + or getattr(x2, "allocator", None) + or drv.mem_alloc) + + if np.isscalar(x1) and np.isscalar(x2): + if out is None: + out = empty(shape=(), dtype=np.bool_, allocator=allocator) + + if operator == "&&": + out[:] = np.logical_and(x1, x2) + else: + out[:] = np.logical_or(x1, x2) + elif np.isscalar(x1) or np.isscalar(x2): + scalar_arg, = [x for x in (x1, x2) if np.isscalar(x)] + ary_arg, = [x for x in (x1, x2) if not np.isscalar(x)] + if not isinstance(ary_arg, GPUArray): + raise ValueError("logical_and can take either scalar or GPUArrays" + " as inputs") + + out = out or ary_arg._new_like_me(dtype=np.bool_) + + assert out.shape == ary_arg.shape and out.dtype == np.bool_ + + func = elementwise.get_scalar_op_kernel(ary_arg.dtype, + np.dtype(type(scalar_arg)), + out.dtype, + operator) + + 
func.prepared_async_call(out._grid, out._block, + None, + ary_arg.gpudata, + scalar_arg, + out.gpudata, + out.mem_size) + else: + if not (isinstance(x1, GPUArray) and isinstance(x2, GPUArray)): + raise ValueError("logical_and can take either scalar or GPUArrays" + " as inputs") + if x1.shape != x2.shape: + raise NotImplementedError("Broadcasting not supported") + + if out is None: + out = x1._new_like_me(dtype=np.bool_) + + assert out.shape == x1.shape and out.dtype == np.bool_ + + func = elementwise.get_binary_op_kernel( + x1.dtype, x2.dtype, out.dtype, operator + ) + func.prepared_async_call(out._grid, out._block, + None, + x1.gpudata, + x2.gpudata, + out.gpudata, + out.mem_size) + + return out + + +def logical_and(x1, x2, /, out=None, *, allocator=None): + return _logical_op(x1, x2, out, allocator, "&&") + + +def logical_or(x1, x2, /, out=None, *, allocator=None): + return _logical_op(x1, x2, out, allocator, "||") + + +def logical_not(x, /, out=None, *, allocator=drv.mem_alloc): + if np.isscalar(x): + out = out or empty(shape=(), dtype=np.bool_, allocator=allocator) + out[:] = np.logical_not(x) + else: + out = out or empty(shape=x.shape, dtype=np.bool_, allocator=allocator) + func = elementwise.get_logical_not_kernel(x.dtype, out.dtype) + func.prepared_async_call(out._grid, out._block, + None, + x.gpudata, + out.gpudata, + out.mem_size) + + return out + +# }}} + # vim: foldmethod=marker diff --git a/pycuda/reduction.py b/pycuda/reduction.py index 2651353f..deb254aa 100644 --- a/pycuda/reduction.py +++ b/pycuda/reduction.py @@ -360,6 +360,32 @@ def get_sum_kernel(dtype_out, dtype_in): ) +@context_dependent_memoize +def get_any_kernel(dtype_out, dtype_in): + if dtype_out is None: + dtype_out = dtype_in + + return ReductionKernel( + dtype_out, + "0", + "(a != 0) || (b != 0)", + arguments="const {tp} *in".format(tp=dtype_to_ctype(dtype_in)), + ) + + +@context_dependent_memoize +def get_all_kernel(dtype_out, dtype_in): + if dtype_out is None: + dtype_out = dtype_in + + return ReductionKernel( + dtype_out, + "1", + "(a != 0) && (b != 0)", + arguments="const {tp} *in".format(tp=dtype_to_ctype(dtype_in)), + ) + + @context_dependent_memoize def get_subset_sum_kernel(dtype_out, dtype_subset, dtype_in): if dtype_out is None: diff --git a/pycuda/tools.py b/pycuda/tools.py index 05ac3c52..a92883c9 100644 --- a/pycuda/tools.py +++ b/pycuda/tools.py @@ -527,6 +527,20 @@ def mark_cuda_test(inner_f): return mark_test.cuda(f) +def init_cuda_context_fixture(): + import pycuda.driver as cuda + cuda.init() + ctx = make_default_context() + assert isinstance(ctx.get_device().name(), str) + assert isinstance(ctx.get_device().compute_capability(), tuple) + assert isinstance(ctx.get_device().get_attributes(), dict) + yield + + from gc import collect + ctx.pop() + clear_context_caches() + collect() + # }}} diff --git a/setup.py b/setup.py index 962b3f41..9adccd87 100644 --- a/setup.py +++ b/setup.py @@ -225,7 +225,7 @@ def main(): setup_requires=[ "numpy>=1.6", ], - python_requires="~=3.6", + python_requires="~=3.8", install_requires=[ "pytools>=2011.2", "appdirs>=1.4.0", diff --git a/test/test_driver.py b/test/test_driver.py index 98f3c8aa..9deae3be 100644 --- a/test/test_driver.py +++ b/test/test_driver.py @@ -1071,6 +1071,7 @@ class WrappedAllocation(drv.PointerHolderBase): return int(self.wrapped) +@mark_cuda_test def test_pointer_holder_base(): alloc = WrappedAllocation(drv.mem_alloc(1024)) ary = gpuarray.GPUArray((1024,), np.uint8, gpudata=alloc) @@ -1102,6 +1103,7 @@ class CudaArrayInterfaceImpl: return 
self._ptr +@mark_cuda_test def test_pass_cai_array(): dtype = np.int32 size = 1024 diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index 0246fc27..cb246dcf 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -3,17 +3,23 @@ import numpy as np import numpy.linalg as la import sys -from pycuda.tools import mark_cuda_test +from pycuda.tools import init_cuda_context_fixture from pycuda.characterize import has_double_support import pycuda.gpuarray as gpuarray import pycuda.driver as drv from pycuda.compiler import SourceModule +import pytest +@pytest.fixture(autouse=True) +def init_cuda_context(): + yield from init_cuda_context_fixture() + + +@pytest.mark.cuda class TestGPUArray: - @mark_cuda_test def test_pow_array(self): a = np.array([1, 2, 3, 4, 5]).astype(np.float32) a_gpu = gpuarray.to_gpu(a) @@ -30,9 +36,9 @@ class TestGPUArray: a_gpu = a_gpu.get() np.testing.assert_allclose(pow(a, b), a_gpu, rtol=1e-6) - @mark_cuda_test - def test_pow_number(self): - a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_pow_number(self, dtype): + a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(dtype) a_gpu = gpuarray.to_gpu(a) result = pow(a_gpu, 2).get() @@ -42,7 +48,6 @@ class TestGPUArray: a_gpu = a_gpu.get() np.testing.assert_allclose(a ** 2, a_gpu, rtol=1e-6) - @mark_cuda_test def test_rpow_array(self): scalar = np.random.rand() a = abs(np.random.rand(10)) @@ -57,18 +62,15 @@ class TestGPUArray: result = (a_gpu ** scalar).get() np.testing.assert_allclose(a ** scalar, result) - @mark_cuda_test def test_numpy_integer_shape(self): gpuarray.empty(np.int32(17), np.float32) gpuarray.empty((np.int32(17), np.int32(17)), np.float32) - @mark_cuda_test def test_ndarray_shape(self): gpuarray.empty(np.array(3), np.float32) gpuarray.empty(np.array([3]), np.float32) gpuarray.empty(np.array([2, 3]), np.float32) - @mark_cuda_test def test_abs(self): a = -gpuarray.arange(111, dtype=np.float32) res = a.get() @@ -84,13 +86,11 @@ class TestGPUArray: assert abs(res[i]) >= 0 assert res[i] == i - @mark_cuda_test def test_len(self): a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) a_cpu = gpuarray.to_gpu(a) assert len(a_cpu) == 10 - @mark_cuda_test def test_multiply(self): """Test the muliplication of an array with a scalar. 
""" @@ -103,7 +103,6 @@ class TestGPUArray: assert (a * scalar == a_doubled).all() - @mark_cuda_test def test_rmul_yields_right_type(self): a = np.array([1, 2, 3, 4, 5]).astype(np.float32) a_gpu = gpuarray.to_gpu(a) @@ -114,20 +113,29 @@ class TestGPUArray: two_a = np.float32(2) * a_gpu assert isinstance(two_a, gpuarray.GPUArray) - @mark_cuda_test def test_multiply_array(self): """Test the multiplication of two arrays.""" a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) + b = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]).astype(np.float32) + c = np.array(2) a_gpu = gpuarray.to_gpu(a) - b_gpu = gpuarray.to_gpu(a) + b_gpu = gpuarray.to_gpu(b) + c_gpu = gpuarray.to_gpu(c) + + a_mul_b = (a_gpu * b_gpu).get() + assert (a * b == a_mul_b).all() + + b_mul_a = (b_gpu * a_gpu).get() + assert (b * a == b_mul_a).all() - a_squared = (b_gpu * a_gpu).get() + a_mul_c = (a_gpu * c_gpu).get() + assert (a * c == a_mul_c).all() - assert (a * a == a_squared).all() + b_mul_c = (b_gpu * c_gpu).get() + assert (b * c == b_mul_c).all() - @mark_cuda_test def test_unit_multiply_array(self): a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) @@ -136,17 +144,25 @@ class TestGPUArray: np.testing.assert_allclose(+a_gpu.get(), +a, rtol=1e-6) np.testing.assert_allclose(-a_gpu.get(), -a, rtol=1e-6) - @mark_cuda_test def test_addition_array(self): """Test the addition of two arrays.""" a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) a_gpu = gpuarray.to_gpu(a) + b = np.array(1).astype(np.float32) + b_gpu = gpuarray.to_gpu(b) a_added = (a_gpu + a_gpu).get() + a_added_scalar = (a_gpu + 1).get() + scalar_added_a = (1 + a_gpu).get() + a_gpu_pl_b_gpu = (a_gpu + b_gpu).get() + b_gpu_pl_a_gpu = (b_gpu + a_gpu).get() assert (a + a == a_added).all() + assert (a + 1 == a_added_scalar).all() + assert (1 + a == scalar_added_a).all() + assert (a + b == a_gpu_pl_b_gpu).all() + assert (b + a == b_gpu_pl_a_gpu).all() - @mark_cuda_test def test_iaddition_array(self): """Test the inplace addition of two arrays.""" @@ -157,7 +173,6 @@ class TestGPUArray: assert (a + a == a_added).all() - @mark_cuda_test def test_addition_scalar(self): """Test the addition of an array and a scalar.""" @@ -167,7 +182,6 @@ class TestGPUArray: assert (7 + a == a_added).all() - @mark_cuda_test def test_iaddition_scalar(self): """Test the inplace addition of an array and a scalar.""" @@ -178,15 +192,16 @@ class TestGPUArray: assert (7 + a == a_added).all() - @mark_cuda_test def test_substract_array(self): """Test the subtraction of two arrays.""" # test data a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) b = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]).astype(np.float32) + c = np.array(1).astype(np.float32) a_gpu = gpuarray.to_gpu(a) b_gpu = gpuarray.to_gpu(b) + c_gpu = gpuarray.to_gpu(c) result = (a_gpu - b_gpu).get() assert (a - b == result).all() @@ -194,7 +209,12 @@ class TestGPUArray: result = (b_gpu - a_gpu).get() assert (b - a == result).all() - @mark_cuda_test + result = (a_gpu - c_gpu).get() + assert (a - c == result).all() + + result = (c_gpu - a_gpu).get() + assert (c - a == result).all() + def test_substract_scalar(self): """Test the subtraction of an array and a scalar.""" @@ -210,7 +230,6 @@ class TestGPUArray: result = (7 - a_gpu).get() assert (7 - a == result).all() - @mark_cuda_test def test_divide_scalar(self): """Test the division of an array and a scalar.""" @@ -223,16 +242,17 @@ class TestGPUArray: result = (2 / a_gpu).get() assert (2 / a == result).all() - 
@mark_cuda_test def test_divide_array(self): """Test the division of an array and a scalar. """ # test data a = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]).astype(np.float32) b = np.array([10, 10, 10, 10, 10, 10, 10, 10, 10, 10]).astype(np.float32) + c = np.array(2) a_gpu = gpuarray.to_gpu(a) b_gpu = gpuarray.to_gpu(b) + c_gpu = gpuarray.to_gpu(c) a_divide = (a_gpu / b_gpu).get() assert (np.abs(a / b - a_divide) < 1e-3).all() @@ -240,7 +260,12 @@ class TestGPUArray: a_divide = (b_gpu / a_gpu).get() assert (np.abs(b / a - a_divide) < 1e-3).all() - @mark_cuda_test + a_divide = (a_gpu / c_gpu).get() + assert (np.abs(a / c - a_divide) < 1e-3).all() + + a_divide = (c_gpu / a_gpu).get() + assert (np.abs(c / a - a_divide) < 1e-3).all() + def test_random(self): from pycuda.curandom import rand as curand @@ -255,7 +280,6 @@ class TestGPUArray: assert (0 <= a).all() assert (a < 1).all() - @mark_cuda_test def test_curand_wrappers(self): from pycuda.curandom import get_curand_version @@ -330,7 +354,6 @@ class TestGPUArray: # # Compare with scipy.stats.poisson.pmf(v - 1, v) # assert np.isclose(0.12511, tmp, atol=0.002) - @mark_cuda_test def test_array_gt(self): """Test whether array contents are > the other array's contents""" @@ -343,7 +366,6 @@ class TestGPUArray: assert result[0] assert not result[1] - @mark_cuda_test def test_array_lt(self): """Test whether array contents are < the other array's contents""" @@ -356,7 +378,6 @@ class TestGPUArray: assert result[0] assert not result[1] - @mark_cuda_test def test_array_le(self): """Test whether array contents are <= the other array's contents""" @@ -370,7 +391,6 @@ class TestGPUArray: assert result[1] assert not result[2] - @mark_cuda_test def test_array_ge(self): """Test whether array contents are >= the other array's contents""" @@ -384,7 +404,6 @@ class TestGPUArray: assert result[1] assert not result[2] - @mark_cuda_test def test_array_eq(self): """Test whether array contents are == the other array's contents""" @@ -397,7 +416,6 @@ class TestGPUArray: assert not result[0] assert result[1] - @mark_cuda_test def test_array_ne(self): """Test whether array contents are != the other array's contents""" @@ -410,7 +428,6 @@ class TestGPUArray: assert result[0] assert not result[1] - @mark_cuda_test def test_nan_arithmetic(self): def make_nan_contaminated_vector(size): shape = (size,) @@ -435,7 +452,6 @@ class TestGPUArray: assert (np.isnan(ab) == np.isnan(ab_gpu)).all() - @mark_cuda_test def test_elwise_kernel(self): from pycuda.curandom import rand as curand @@ -455,7 +471,6 @@ class TestGPUArray: assert la.norm((c_gpu - (5 * a_gpu + 6 * b_gpu)).get()) < 1e-5 - @mark_cuda_test def test_ranged_elwise_kernel(self): from pycuda.elementwise import ElementwiseKernel @@ -479,7 +494,6 @@ class TestGPUArray: assert la.norm(a_cpu - a_gpu.get()) == 0, i - @mark_cuda_test def test_take(self): idx = gpuarray.arange(0, 10000, 2, dtype=np.uint32) for dtype in [np.float32, np.complex64]: @@ -489,12 +503,10 @@ class TestGPUArray: assert (a_host[idx.get()] == result.get()).all() - @mark_cuda_test def test_arange(self): a = gpuarray.arange(12, dtype=np.float32) assert (np.arange(12, dtype=np.float32) == a.get()).all() - @mark_cuda_test def test_ones(self): ones = np.ones(10) @@ -503,35 +515,30 @@ class TestGPUArray: np.testing.assert_allclose(ones, ones_gpu.get(), rtol=1e-6) assert ones.dtype == ones_gpu.dtype - @mark_cuda_test - def test_stack(self): - - orders = ["F", "C"] - input_dims_lst = [0, 1, 2] + @pytest.mark.parametrize("order", ["F", "C"]) + 
@pytest.mark.parametrize("input_dims", [0, 1, 2]) + def test_stack(self, order, input_dims): - for order in orders: - for input_dims in input_dims_lst: - shape = (2, 2, 2)[:input_dims] - axis = -1 if order == "F" else 0 + shape = (2, 2, 2)[:input_dims] + axis = -1 if order == "F" else 0 - from numpy.random import default_rng - rng = default_rng() - x_in = rng.random(size=shape) - y_in = rng.random(size=shape) - x_in = x_in if order == "C" else np.asfortranarray(x_in) - y_in = y_in if order == "C" else np.asfortranarray(y_in) + from numpy.random import default_rng + rng = default_rng() + x_in = rng.random(size=shape) + y_in = rng.random(size=shape) + x_in = x_in if order == "C" else np.asfortranarray(x_in) + y_in = y_in if order == "C" else np.asfortranarray(y_in) - x_gpu = gpuarray.to_gpu(x_in) - y_gpu = gpuarray.to_gpu(y_in) + x_gpu = gpuarray.to_gpu(x_in) + y_gpu = gpuarray.to_gpu(y_in) - numpy_stack = np.stack((x_in, y_in), axis=axis) - gpuarray_stack = gpuarray.stack((x_gpu, y_gpu), axis=axis) + numpy_stack = np.stack((x_in, y_in), axis=axis) + gpuarray_stack = gpuarray.stack((x_gpu, y_gpu), axis=axis) - np.testing.assert_allclose(gpuarray_stack.get(), numpy_stack) + np.testing.assert_allclose(gpuarray_stack.get(), numpy_stack) - assert gpuarray_stack.shape == numpy_stack.shape + assert gpuarray_stack.shape == numpy_stack.shape - @mark_cuda_test def test_concatenate(self): from pycuda.curandom import rand as curand @@ -550,7 +557,6 @@ class TestGPUArray: assert cat.shape == cat_dev.shape - @mark_cuda_test def test_reverse(self): a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) a_cpu = gpuarray.to_gpu(a) @@ -562,7 +568,6 @@ class TestGPUArray: for i in range(0, 10): assert a[len(a) - 1 - i] == b[i] - @mark_cuda_test def test_sum(self): from pycuda.curandom import rand as curand @@ -575,7 +580,56 @@ class TestGPUArray: assert abs(sum_a_gpu - sum_a) / abs(sum_a) < 1e-4 - @mark_cuda_test + @pytest.mark.parametrize("dtype", [np.int32, np.bool, np.float32, np.float64]) + def test_any(self, dtype): + + ary_list = [np.ones(10, dtype), + np.zeros(1, dtype), + np.ones(1, dtype), + np.empty(10, dtype)] + + for ary in ary_list: + ary_gpu = gpuarray.to_gpu(ary) + any_ary = np.any(ary) + any_ary_gpu = ary_gpu.any().get() + np.testing.assert_array_equal(any_ary_gpu, any_ary) + assert any_ary_gpu.dtype == any_ary.dtype + + import itertools + for _array in list(itertools.product([0, 1], [0, 1], [0, 1])): + array = np.array(_array, dtype) + array_gpu = gpuarray.to_gpu(array) + any_array = np.any(array) + any_array_gpu = array_gpu.any().get() + + np.testing.assert_array_equal(any_array_gpu, any_array) + assert any_array_gpu.dtype == any_array.dtype + + @pytest.mark.parametrize("dtype", [np.int32, np.bool, np.float32, np.float64]) + def test_all(self, dtype): + + ary_list = [np.ones(10, dtype), + np.zeros(1, dtype), + np.ones(1, dtype), + np.empty(10, dtype)] + + for ary in ary_list: + ary_gpu = gpuarray.to_gpu(ary) + all_ary = np.all(ary) + all_ary_gpu = ary_gpu.all().get() + np.testing.assert_array_equal(all_ary_gpu, all_ary) + assert all_ary_gpu.dtype == all_ary.dtype + + import itertools + for _array in list(itertools.product([0, 1], [0, 1], [0, 1])): + array = np.array(_array, dtype) + array_gpu = gpuarray.to_gpu(array) + all_array = np.all(array) + all_array_gpu = array_gpu.all().get() + + np.testing.assert_array_equal(all_array_gpu, all_array) + assert all_array_gpu.dtype == all_array.dtype + def test_minmax(self): from pycuda.curandom import rand as curand @@ -594,7 +648,6 @@ class 
TestGPUArray: assert op_a_gpu == op_a, (op_a_gpu, op_a, dtype, what) - @mark_cuda_test def test_subset_minmax(self): from pycuda.curandom import rand as curand @@ -628,41 +681,38 @@ class TestGPUArray: assert min_a_gpu == min_a - @mark_cuda_test - def test_dot(self): + @pytest.mark.parametrize("sz", [2, + 3, + 4, + 5, + 6, + 7, + 31, + 32, + 33, + 127, + 128, + 129, + 255, + 256, + 257, + 16384 - 993, + 20000, + ]) + def test_dot(self, sz): from pycuda.curandom import rand as curand - for sz in [ - 2, - 3, - 4, - 5, - 6, - 7, - 31, - 32, - 33, - 127, - 128, - 129, - 255, - 256, - 257, - 16384 - 993, - 20000, - ]: - a_gpu = curand((sz,)) - a = a_gpu.get() - b_gpu = curand((sz,)) - b = b_gpu.get() + a_gpu = curand((sz,)) + a = a_gpu.get() + b_gpu = curand((sz,)) + b = b_gpu.get() - dot_ab = np.dot(a, b) + dot_ab = np.dot(a, b) - dot_ab_gpu = gpuarray.dot(a_gpu, b_gpu).get() + dot_ab_gpu = gpuarray.dot(a_gpu, b_gpu).get() - assert abs(dot_ab_gpu - dot_ab) / abs(dot_ab) < 1e-4 + assert abs(dot_ab_gpu - dot_ab) / abs(dot_ab) < 1e-4 - @mark_cuda_test def test_slice(self): from pycuda.curandom import rand as curand @@ -681,7 +731,6 @@ class TestGPUArray: assert la.norm(a_gpu_slice.get() - a_slice) == 0 - @mark_cuda_test def test_2d_slice_c(self): from pycuda.curandom import rand as curand @@ -701,7 +750,6 @@ class TestGPUArray: assert la.norm(a_gpu_slice.get() - a_slice) == 0 - @mark_cuda_test def test_2d_slice_f(self): from pycuda.curandom import rand as curand import pycuda.gpuarray as gpuarray @@ -725,7 +773,22 @@ class TestGPUArray: assert la.norm(a_gpu_slice.get() - a_slice) == 0 - @mark_cuda_test + def test_where(self): + a = np.array([1, 0, -1]) + b = np.array([2, 2, 2]) + c = np.array([3, 3, 3]) + + import pycuda.gpuarray as gpuarray + + a_gpu = gpuarray.to_gpu(a) + b_gpu = gpuarray.to_gpu(b) + c_gpu = gpuarray.to_gpu(c) + + result = gpuarray.where(a_gpu, b_gpu, c_gpu).get() + result_ref = np.where(a, b, c) + + np.testing.assert_allclose(result_ref, result, rtol=1e-5) + def test_if_positive(self): from pycuda.curandom import rand as curand @@ -746,7 +809,6 @@ class TestGPUArray: assert la.norm(max_a_b_gpu.get() - np.maximum(a, b)) == 0 assert la.norm(min_a_b_gpu.get() - np.minimum(a, b)) == 0 - @mark_cuda_test def test_take_put(self): for n in [5, 17, 333]: one_field_size = 8 @@ -768,7 +830,6 @@ class TestGPUArray: drv.Context.synchronize() - @mark_cuda_test def test_astype(self): from pycuda.curandom import rand as curand @@ -791,7 +852,6 @@ class TestGPUArray: assert a2.dtype == np.float32 assert la.norm(a - a2) / la.norm(a) < 1e-7 - @mark_cuda_test def test_complex_bits(self): from pycuda.curandom import rand as curand @@ -817,7 +877,7 @@ class TestGPUArray: # verify conj with out parameter z_out = z.astype(np.complex64) assert z_out is z.conj(out=z_out) - assert la.norm(z.get().conj() - z_out.get()) < 1e-7 + assert la.norm(z.get().conj() - z_out.get()) < 5e-6 # verify contiguity is preserved for order in ["C", "F"]: @@ -836,7 +896,6 @@ class TestGPUArray: assert zdata.imag.flags.f_contiguous assert zdata.conj().flags.f_contiguous - @mark_cuda_test def test_pass_slice_to_kernel(self): mod = SourceModule( """ @@ -857,9 +916,9 @@ class TestGPUArray: a = a_gpu.get() assert (a[255:257] == np.array([1, 2], np.float32)).all() - assert (a[255 * 256 - 1: 255 * 256 + 1] == np.array([2, 1], np.float32)).all() + np.testing.assert_array_equal(a[255 * 256 - 1: 255 * 256 + 1], + np.array([2, 1], np.float32)) - @mark_cuda_test def test_scan(self): from pycuda.scan import ExclusiveScanKernel, 
InclusiveScanKernel @@ -888,7 +947,6 @@ class TestGPUArray: assert (gpu_data.get() == desired_result).all() - @mark_cuda_test def test_stride_preservation(self): A = np.random.rand(3, 3) AT = A.T @@ -897,18 +955,15 @@ class TestGPUArray: print((AT_GPU.flags.f_contiguous, AT_GPU.flags.c_contiguous)) assert np.allclose(AT_GPU.get(), AT) - @mark_cuda_test def test_vector_fill(self): a_gpu = gpuarray.GPUArray(100, dtype=gpuarray.vec.float3) a_gpu.fill(gpuarray.vec.make_float3(0.0, 0.0, 0.0)) a = a_gpu.get() assert a.dtype == gpuarray.vec.float3 - @mark_cuda_test def test_create_complex_zeros(self): gpuarray.zeros(3, np.complex64) - @mark_cuda_test def test_reshape(self): a = np.arange(128).reshape(8, 16).astype(np.float32) a_gpu = gpuarray.to_gpu(a) @@ -941,7 +996,6 @@ class TestGPUArray: a_gpu = a_gpu.reshape((4, 32)) assert a_gpu.flags.c_contiguous - @mark_cuda_test def test_view(self): a = np.arange(128).reshape(8, 16).astype(np.float32) a_gpu = gpuarray.to_gpu(a) @@ -958,7 +1012,6 @@ class TestGPUArray: view = a_gpu.view(np.int16) assert view.shape == (8, 32) and view.dtype == np.int16 - @mark_cuda_test def test_squeeze(self): shape = (40, 2, 5, 100) a_cpu = np.random.random(size=shape) @@ -975,7 +1028,8 @@ class TestGPUArray: assert a_gpu_squeezed_slice.flags.c_contiguous # Check that we get the original values out - assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel()) + np.testing.assert_array_equal(a_gpu_slice.get().ravel(), + a_gpu_squeezed_slice.get().ravel()) # Slice with length 1 on dimensions 2 a_gpu_slice = a_gpu[:, :, 2:3, :] @@ -988,9 +1042,9 @@ class TestGPUArray: assert not a_gpu_squeezed_slice.flags.c_contiguous # Check that we get the original values out - assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel()) + np.testing.assert_array_equal(a_gpu_slice.get().ravel(), + a_gpu_squeezed_slice.get().ravel()) - @mark_cuda_test def test_struct_reduce(self): preamble = """ struct minmax_collector @@ -1062,7 +1116,6 @@ class TestGPUArray: assert minmax["cur_min"] == np.min(a) assert minmax["cur_max"] == np.max(a) - @mark_cuda_test def test_reduce_out(self): from pycuda.curandom import rand as curand @@ -1080,14 +1133,15 @@ class TestGPUArray: assert np.alltrue(a.max(axis=1) == max_gpu.get()) - @mark_cuda_test def test_sum_allocator(self): # FIXME from pytest import skip skip("https://github.com/inducer/pycuda/issues/163") - # crashes with terminate called after throwing an instance of 'pycuda::error' - # what(): explicit_context_dependent failed: invalid device context - no currently active context? + # crashes with terminate called after throwing an instance + # of 'pycuda::error' + # what(): explicit_context_dependent failed: invalid device context - + # no currently active context? 
import pycuda.tools @@ -1107,7 +1161,6 @@ class TestGPUArray: assert b.allocator == a.allocator assert c.allocator == pool.allocate - @mark_cuda_test def test_dot_allocator(self): # FIXME from pytest import skip @@ -1139,7 +1192,6 @@ class TestGPUArray: assert dot_gpu_1.allocator == a_gpu.allocator assert dot_gpu_2.allocator == pool.allocate - @mark_cuda_test def test_view_and_strides(self): from pycuda.curandom import rand as curand @@ -1152,7 +1204,6 @@ class TestGPUArray: assert np.array_equal(y.get(), X.get()[:3, :5]) - @mark_cuda_test def test_scalar_comparisons(self): a = np.array([1.0, 0.25, 0.1, -0.1, 0.0]) a_gpu = gpuarray.to_gpu(a) @@ -1173,7 +1224,6 @@ class TestGPUArray: x = (a == 1).astype(a.dtype) assert (x == x_gpu.get()).all() - @mark_cuda_test def test_minimum_maximum_scalar(self): from pycuda.curandom import rand as curand @@ -1189,17 +1239,14 @@ class TestGPUArray: assert la.norm(max_a0_gpu.get() - np.maximum(a, 0)) == 0 assert la.norm(min_a0_gpu.get() - np.minimum(0, a)) == 0 - @mark_cuda_test def test_transpose(self): from pycuda.curandom import rand as curand a_gpu = curand((10, 20, 30)) a = a_gpu.get() - # assert np.allclose(a_gpu.transpose((1,2,0)).get(), a.transpose((1,2,0))) # not contiguous assert np.allclose(a_gpu.T.get(), a.T) - @mark_cuda_test def test_newaxis(self): from pycuda.curandom import rand as curand @@ -1212,7 +1259,6 @@ class TestGPUArray: assert b_gpu.shape == b.shape assert b_gpu.strides == b.strides - @mark_cuda_test def test_copy(self): from pycuda.curandom import rand as curand @@ -1251,7 +1297,6 @@ class TestGPUArray: a_gpu.get()[start:stop:step, :, start:stop:step], ) - @mark_cuda_test def test_get_set(self): import pycuda.gpuarray as gpuarray @@ -1265,7 +1310,6 @@ class TestGPUArray: assert np.allclose(a_gpu.get(), a) assert np.allclose(a_gpu[1:3, 1:3, 1:3].get(), a[1:3, 1:3, 1:3]) - @mark_cuda_test def test_zeros_like_etc(self): shape = (16, 16) a = np.random.randn(*shape).astype(np.float32) @@ -1315,6 +1359,50 @@ class TestGPUArray: assert new_z.dtype == np.complex64 assert new_z.shape == arr.shape + def test_logical_and_or(self): + rng = np.random.default_rng(seed=0) + for op in ["logical_and", "logical_or"]: + x_np = rng.random((10, 4)) + y_np = rng.random((10, 4)) + zeros_np = np.zeros((10, 4)) + ones_np = np.ones((10, 4)) + + x_cu = gpuarray.to_gpu(x_np) + y_cu = gpuarray.to_gpu(y_np) + zeros_cu = gpuarray.zeros((10, 4), "float64") + ones_cu = gpuarray.ones((10, 4)) + + np.testing.assert_array_equal( + getattr(gpuarray, op)(x_cu, y_cu).get(), + getattr(np, op)(x_np, y_np)) + np.testing.assert_array_equal( + getattr(gpuarray, op)(x_cu, ones_cu).get(), + getattr(np, op)(x_np, ones_np)) + np.testing.assert_array_equal( + getattr(gpuarray, op)(x_cu, zeros_cu).get(), + getattr(np, op)(x_np, zeros_np)) + np.testing.assert_array_equal( + getattr(gpuarray, op)(x_cu, 1.0).get(), + getattr(np, op)(x_np, ones_np)) + np.testing.assert_array_equal( + getattr(gpuarray, op)(x_cu, 0.0).get(), + getattr(np, op)(x_np, 0.0)) + + def test_logical_not(self): + rng = np.random.default_rng(seed=0) + x_np = rng.random((10, 4)) + x_cu = gpuarray.to_gpu(x_np) + + np.testing.assert_array_equal( + gpuarray.logical_not(x_cu).get(), + np.logical_not(x_np)) + np.testing.assert_array_equal( + gpuarray.logical_not(gpuarray.zeros(10, "float64")).get(), + np.logical_not(np.zeros(10))) + np.testing.assert_array_equal( + gpuarray.logical_not(gpuarray.ones(10)).get(), + np.logical_not(np.ones(10))) + if __name__ == "__main__": # make sure that import failures get reported, 
instead of skipping the tests. -- GitLab From c15519f2a5eb9b314addbf71c43a46dac913787d Mon Sep 17 00:00:00 2001 From: Mit Kotak Date: Thu, 4 Aug 2022 23:17:38 -0500 Subject: [PATCH 2/2] fixed merge conflicts --- pycuda/elementwise.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pycuda/elementwise.py b/pycuda/elementwise.py index a3f23b20..7d633011 100644 --- a/pycuda/elementwise.py +++ b/pycuda/elementwise.py @@ -782,12 +782,8 @@ def get_scalar_op_kernel(dtype_x, dtype_a, dtype_y, operator): "%(tp_x)s *x, %(tp_a)s a, %(tp_y)s *y" % { "tp_x": dtype_to_ctype(dtype_x), - "tp_a": dtype_to_ctype(dtype_a), "tp_y": dtype_to_ctype(dtype_y), -<<<<<<< HEAD "tp_a": dtype_to_ctype(dtype_a), -======= ->>>>>>> main }, "y[i] = x[i] %s a" % operator, "scalarop_kernel", -- GitLab
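
A quick way to smoke-test what these two patches add (0-d "scalar" broadcasting in the binary operators, the logical_and/logical_or/logical_not helpers, the any/all reductions, and where) is the short script below. It is an illustrative sketch rather than part of the patch: it assumes this branch of PyCUDA is installed on a machine with a CUDA device, uses pycuda.autoinit to create a context, and the array values are arbitrary; the calls mirror the usage exercised in test_gpuarray.py above.

    import numpy as np

    import pycuda.autoinit  # noqa: F401  (creates a default CUDA context)
    import pycuda.gpuarray as gpuarray

    a = np.arange(1, 11, dtype=np.float32)
    a_gpu = gpuarray.to_gpu(a)
    two_gpu = gpuarray.to_gpu(np.array(2, dtype=np.float32))  # 0-d array

    # Binary operators now broadcast a 0-d GPUArray against any shape.
    print((a_gpu + two_gpu).get())   # a + 2
    print((a_gpu * two_gpu).get())   # a * 2
    print((two_gpu / a_gpu).get())   # 2 / a

    # Elementwise logical helpers return bool arrays.
    print(gpuarray.logical_and(a_gpu, a_gpu - 5).get())  # False only where a == 5
    print(gpuarray.logical_or(a_gpu, 0.0).get())
    print(gpuarray.logical_not(a_gpu - a_gpu).get())

    # any/all reductions, available as methods or module-level functions.
    print(a_gpu.any().get(), gpuarray.all(a_gpu - a_gpu).get())

    # numpy.where-style selection; shapes and dtypes must match (no broadcasting yet).
    then_ = gpuarray.to_gpu(np.full(10, 1.0, dtype=np.float32))
    else_ = gpuarray.to_gpu(np.full(10, -1.0, dtype=np.float32))
    print(gpuarray.where(a_gpu - 5, then_, else_).get())  # -1.0 where a == 5, else 1.0

The 0-d operands above take the new broadcast path in _get_broadcasted_binary_op_result and the x_is_scalar/y_is_scalar kernel variants; shapes other than () versus a matching shape still raise NotImplementedError, as the patch documents.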