From 5733d12cc7b18460acca3e51b47bb67db16918d1 Mon Sep 17 00:00:00 2001
From: Mit Kotak
Date: Sun, 3 Jul 2022 11:40:25 -0500
Subject: [PATCH] Added bitwise operations + added reverse op kernel

---
 pycuda/elementwise.py |  60 ++++++++++++--
 pycuda/gpuarray.py    | 158 ++++++++++++++++++++++++++----------
 test/test_gpuarray.py | 185 ++++++++++++++++++++++++------------------
 3 files changed, 274 insertions(+), 129 deletions(-)

diff --git a/pycuda/elementwise.py b/pycuda/elementwise.py
index 4e8601f0..b07d9c1f 100644
--- a/pycuda/elementwise.py
+++ b/pycuda/elementwise.py
@@ -43,7 +43,6 @@ def get_elwise_module(
     after_loop="",
 ):
     from pycuda.compiler import SourceModule
-
     return SourceModule(
         """
         #include <pycuda-complex.hpp>
@@ -464,7 +463,20 @@ def get_linear_combination_kernel(summand_descriptors, dtype_z):
 
 
 @context_dependent_memoize
-def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z):
+def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z,
+                      x_is_scalar=False, y_is_scalar=False):
+    """
+    Returns a kernel corresponding to ``z = ax + by``.
+
+    :arg x_is_scalar: A :class:`bool` which is *True* only if `x` is a scalar :class:`gpuarray`.
+    :arg y_is_scalar: A :class:`bool` which is *True* only if `y` is a scalar :class:`gpuarray`.
+    """
+    out_t = dtype_to_ctype(dtype_z)
+    x = "x[0]" if x_is_scalar else "x[i]"
+    ax = f"a*(({out_t}) {x})"
+    y = "y[0]" if y_is_scalar else "y[i]"
+    by = f"b*(({out_t}) {y})"
+    result = f"{ax} + {by}"
     return get_elwise_kernel(
         "%(tp_x)s a, %(tp_x)s *x, %(tp_y)s b, %(tp_y)s *y, %(tp_z)s *z"
         % {
@@ -472,7 +484,7 @@ def get_axpbyz_kernel(dtype_x, dtype_y, dtype_z):
             "tp_y": dtype_to_ctype(dtype_y),
             "tp_z": dtype_to_ctype(dtype_z),
         },
-        "z[i] = a*x[i] + b*y[i]",
+        f"z[i] = {result}",
         "axpbyz",
     )
 
@@ -488,7 +500,17 @@ def get_axpbz_kernel(dtype_x, dtype_z):
 
 
 @context_dependent_memoize
-def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator):
+def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator,
+                         x_is_scalar, y_is_scalar):
+    """
+    Returns a kernel corresponding to ``z = x (operator) y``.
+
+    :arg x_is_scalar: A :class:`bool` which is *True* only if `x` is a scalar :class:`gpuarray`.
+    :arg y_is_scalar: A :class:`bool` which is *True* only if `y` is a scalar :class:`gpuarray`.
+    """
+    x = "x[0]" if x_is_scalar else "x[i]"
+    y = "y[0]" if y_is_scalar else "y[i]"
+    result = f"{x} {operator} {y}"
     return get_elwise_kernel(
         "%(tp_x)s *x, %(tp_y)s *y, %(tp_z)s *z"
         % {
@@ -496,7 +518,7 @@ def get_binary_op_kernel(dtype_x, dtype_y, dtype_z, operator):
             "tp_y": dtype_to_ctype(dtype_y),
             "tp_z": dtype_to_ctype(dtype_z),
         },
-        "z[i] = x[i] %s y[i]" % operator,
+        f"z[i] = {result}",
         "multiply",
     )
 
@@ -741,14 +763,38 @@ def get_if_positive_kernel(crit_dtype, dtype):
 
 
 @context_dependent_memoize
-def get_scalar_op_kernel(dtype_x, dtype_y, operator):
+def get_scalar_op_kernel(dtype_x, dtype_y, dtype_a, operator):
+    """
+    Returns a kernel corresponding to ``y = x (operator) a``, where,
+    - ``x`` is a :class:`pycuda.gpuarray.Array`
+    - ``a`` is a scalar
+    """
     return get_elwise_kernel(
         "%(tp_x)s *x, %(tp_a)s a, %(tp_y)s *y"
         % {
             "tp_x": dtype_to_ctype(dtype_x),
             "tp_y": dtype_to_ctype(dtype_y),
-            "tp_a": dtype_to_ctype(dtype_x),
+            "tp_a": dtype_to_ctype(dtype_a),
         },
         "y[i] = x[i] %s a" % operator,
         "scalarop_kernel",
     )
+
+
+@context_dependent_memoize
+def get_reverse_scalar_op_kernel(dtype_a, dtype_y, dtype_x, operator):
+    """
+    Returns a kernel corresponding to ``y = a (operator) x``, where,
+    - ``x`` is a :class:`pycuda.gpuarray.Array`
+    - ``a`` is a scalar
+    """
+    return get_elwise_kernel(
+        "%(tp_x)s *x, %(tp_a)s a, %(tp_y)s *y"
+        % {
+            "tp_x": dtype_to_ctype(dtype_x),
+            "tp_y": dtype_to_ctype(dtype_y),
+            "tp_a": dtype_to_ctype(dtype_a),
+        },
+        "y[i] = a %s x[i]" % operator,
+        "reverse_scalarop_kernel",
+    )
diff --git a/pycuda/gpuarray.py b/pycuda/gpuarray.py
index a3dbc100..50c706c3 100644
--- a/pycuda/gpuarray.py
+++ b/pycuda/gpuarray.py
@@ -25,6 +25,18 @@ def _get_common_dtype(obj1, obj2):
     return _get_common_dtype_base(obj1, obj2, has_double_support())
 
 
+def _get_broadcasted_binary_op_result(obj1, obj2,
+                                      dtype_getter=_get_common_dtype):
+
+    if obj1.shape == obj2.shape:
+        return obj1._new_like_me(dtype_getter(obj1, obj2))
+    elif obj1.shape == ():
+        return obj2._new_like_me(dtype_getter(obj1, obj2))
+    elif obj2.shape == ():
+        return obj1._new_like_me(dtype_getter(obj1, obj2))
+    else:
+        raise NotImplementedError("Broadcasting binary operator with shapes:"
+                                  f" {obj1.shape}, {obj2.shape}.")
 
 # {{{ vector types
 
@@ -134,6 +146,17 @@ def splay(n, dev=None):
 
 # {{{ main GPUArray class
 
 
+def _make_bitwise_binary_op(operator):
+    def func_bitwise(self, other):
+        if (
+            (np.issubdtype(self.dtype, np.floating)
+             or np.issubdtype(other.dtype, np.floating))
+        ):
+            raise TypeError("operation `"+operator+"` not supported for input type")
+        func = _make_binary_op(operator)
+        return func(self, other)
+    return func_bitwise
+
 def _make_binary_op(operator):
     def func(self, other):
@@ -141,34 +164,37 @@ def _make_binary_op(operator):
             raise RuntimeError(
                 "only contiguous arrays may " "be used as arguments to this operation"
             )
-
-        if isinstance(other, GPUArray):
-            assert self.shape == other.shape
-
+        if isinstance(other, GPUArray) and isinstance(self, GPUArray):
             if not other.flags.forc:
                 raise RuntimeError(
                     "only contiguous arrays may " "be used as arguments to this operation"
                 )
 
-            result = self._new_like_me()
+            result = _get_broadcasted_binary_op_result(self, other)
             func = elementwise.get_binary_op_kernel(
-                self.dtype, other.dtype, result.dtype, operator
-            )
+                self.dtype,
+                other.dtype,
+                result.dtype,
+                operator,
+                x_is_scalar=(self.shape == ()),
+                y_is_scalar=(other.shape == ()))
+
 
             func.prepared_async_call(
-                self._grid,
-                self._block,
+                result._grid,
+                result._block,
                 None,
                 self.gpudata,
                 other.gpudata,
                 result.gpudata,
-                self.mem_size,
+                result.mem_size,
             )
 
             return result
-        else:  # scalar operator
+        elif isinstance(self, GPUArray):  # scalar operator
+            assert np.isscalar(other)
             result = self._new_like_me()
-            func = elementwise.get_scalar_op_kernel(self.dtype, result.dtype, operator)
+            func = elementwise.get_scalar_op_kernel(self.dtype, result.dtype, type(other), operator)
             func.prepared_async_call(
                 self._grid,
                 self._block,
@@ -179,6 +205,22 @@ def _make_binary_op(operator):
                 self.mem_size,
             )
             return result
+        elif isinstance(other, GPUArray):  # reverse scalar operator
+            assert np.isscalar(self)
+            result = other._new_like_me()
+            func = elementwise.get_reverse_scalar_op_kernel(type(self), result.dtype, other.dtype, operator)
+            func.prepared_async_call(
+                other._grid,
+                other._block,
+                None,
+                self,
+                other.gpudata,
+                result.gpudata,
+                other.mem_size,
+            )
+            return result
+        else:
+            return NotImplemented
 
     return func
 
@@ -391,38 +433,41 @@ class GPUArray:
     def _axpbyz(self, selffac, other, otherfac, out, add_timer=None, stream=None):
         """Compute ``out = selffac * self + otherfac*other``,
         where `other` is a vector.."""
-        assert self.shape == other.shape
         if not self.flags.forc or not other.flags.forc:
             raise RuntimeError(
                 "only contiguous arrays may " "be used as arguments to this operation"
             )
-
-        func = elementwise.get_axpbyz_kernel(self.dtype, other.dtype, out.dtype)
-
+        assert ((self.shape == other.shape == out.shape)
+                or ((self.shape == ()) and other.shape == out.shape)
+                or ((other.shape == ()) and self.shape == out.shape))
+        func = elementwise.get_axpbyz_kernel(
+            self.dtype, other.dtype, out.dtype,
+            x_is_scalar=(self.shape == ()),
+            y_is_scalar=(other.shape == ()))
         if add_timer is not None:
             add_timer(
                 3 * self.size,
                 func.prepared_timed_call(
-                    self._grid,
+                    out._grid,
                     selffac,
-                    self.gpudata,
+                    out.gpudata,
                     otherfac,
                     other.gpudata,
                     out.gpudata,
-                    self.mem_size,
+                    out.mem_size,
                 ),
             )
         else:
             func.prepared_async_call(
-                self._grid,
-                self._block,
+                out._grid,
+                out._block,
                 stream,
                 selffac,
                 self.gpudata,
                 otherfac,
                 other.gpudata,
                 out.gpudata,
-                self.mem_size,
+                out.mem_size,
             )
 
         return out
@@ -454,16 +499,26 @@ class GPUArray:
             raise RuntimeError(
                 "only contiguous arrays may " "be used as arguments to this operation"
            )
+        assert ((self.shape == other.shape == out.shape)
+                or ((self.shape == ()) and other.shape == out.shape)
+                or ((other.shape == ()) and self.shape == out.shape))
+
+        func = elementwise.get_binary_op_kernel(
+            self.dtype,
+            other.dtype,
+            out.dtype,
+            "*",
+            x_is_scalar=(self.shape == ()),
+            y_is_scalar=(other.shape == ()))
 
-        func = elementwise.get_binary_op_kernel(self.dtype, other.dtype, out.dtype, "*")
         func.prepared_async_call(
-            self._grid,
-            self._block,
+            out._grid,
+            out._block,
             stream,
             self.gpudata,
             other.gpudata,
             out.gpudata,
-            self.mem_size,
+            out.mem_size,
         )
 
         return out
@@ -500,17 +555,25 @@ class GPUArray:
                 "only contiguous arrays may " "be used as arguments to this operation"
             )
 
-        assert self.shape == other.shape
+        assert ((self.shape == other.shape == out.shape)
+                or ((self.shape == ()) and other.shape == out.shape)
+                or ((other.shape == ()) and self.shape == out.shape))
 
-        func = elementwise.get_binary_op_kernel(self.dtype, other.dtype, out.dtype, "/")
+        func = elementwise.get_binary_op_kernel(
+            self.dtype,
+            other.dtype,
+            out.dtype,
+            "/",
+            x_is_scalar=(self.shape == ()),
+            y_is_scalar=(other.shape == ()))
         func.prepared_async_call(
-            self._grid,
-            self._block,
+            out._grid,
+            out._block,
             stream,
             self.gpudata,
             other.gpudata,
             out.gpudata,
-            self.mem_size,
+            out.mem_size,
         )
 
         return out
@@ -537,31 +600,35 @@ class GPUArray:
 
         if isinstance(other, GPUArray):
             # add another vector
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = _get_broadcasted_binary_op_result(self, other)
             return self._axpbyz(1, other, 1, result)
-        else:
+
+        elif np.isscalar(other):
             # add a scalar
             if other == 0:
                 return self.copy()
             else:
                 result = self._new_like_me(_get_common_dtype(self, other))
                 return self._axpbz(1, other, result)
-
+        else:
+            return NotImplemented
     __radd__ = __add__
 
     def __sub__(self, other):
         """Substract an array from an array or a scalar from an array."""
 
         if isinstance(other, GPUArray):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = _get_broadcasted_binary_op_result(self, other)
             return self._axpbyz(1, other, -1, result)
-        else:
+        elif np.isscalar(other):
            if other == 0:
                return self.copy()
            else:
                # create a new array for the result
                result = self._new_like_me(_get_common_dtype(self, other))
                return self._axpbz(1, -other, result)
+        else:
+            return NotImplemented
 
     def __rsub__(self, other):
         """Substracts an array by a scalar or an array::
@@ -590,11 +657,13 @@ class GPUArray:
 
     def __mul__(self, other):
         if isinstance(other, GPUArray):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = _get_broadcasted_binary_op_result(self, other)
             return self._elwise_multiply(other, result)
-        else:
+        elif np.isscalar(other):
             result = self._new_like_me(_get_common_dtype(self, other))
             return self._axpbz(other, 0, result)
+        else:
+            return NotImplemented
 
     def __rmul__(self, scalar):
         result = self._new_like_me(_get_common_dtype(self, scalar))
@@ -612,16 +681,17 @@ class GPUArray:
            x = self / n
         """
         if isinstance(other, GPUArray):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = _get_broadcasted_binary_op_result(self, other)
             return self._div(other, result)
-        else:
+        elif np.isscalar(other):
             if other == 1:
                 return self.copy()
             else:
                 # create a new array for the result
                 result = self._new_like_me(_get_common_dtype(self, other))
                 return self._axpbz(1 / other, 0, result)
-
+        else:
+            return NotImplemented
     __truediv__ = __div__
 
     def __rdiv__(self, other):
@@ -1213,6 +1283,12 @@ class GPUArray:
     __ge__ = _make_binary_op(">=")
     __lt__ = _make_binary_op("<")
     __gt__ = _make_binary_op(">")
+    __and__ = _make_bitwise_binary_op("&")
+    __rand__ = __and__
+    __or__ = _make_bitwise_binary_op("|")
+    __ror__ = __or__
+    __xor__ = _make_bitwise_binary_op("^")
+    __rxor__ = __xor__
 
 # }}}
 
diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py
index 73ec3ade..5dac40fd 100644
--- a/test/test_gpuarray.py
+++ b/test/test_gpuarray.py
@@ -3,6 +3,7 @@
 import numpy as np
 import numpy.linalg as la
 import sys
+import pytest
 
 from pycuda.tools import mark_cuda_test
 from pycuda.characterize import has_double_support
@@ -119,13 +120,24 @@ class TestGPUArray:
         """Test the multiplication of two arrays."""
 
         a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
+        b = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]).astype(np.float32)
+        c = np.array(2)
 
         a_gpu = gpuarray.to_gpu(a)
-        b_gpu = gpuarray.to_gpu(a)
+        b_gpu = gpuarray.to_gpu(b)
+        c_gpu = gpuarray.to_gpu(c)
+
+        a_mul_b = (a_gpu * b_gpu).get()
+        assert (a * b == a_mul_b).all()
 
-        a_squared = (b_gpu * a_gpu).get()
+        b_mul_a = (b_gpu * a_gpu).get()
+        assert (b * a == b_mul_a).all()
 
-        assert (a * a == a_squared).all()
+        a_mul_c = (a_gpu * c_gpu).get()
+        assert (a * c == a_mul_c).all()
+
+        b_mul_c = (b_gpu * c_gpu).get()
+        assert (b * c == b_mul_c).all()
 
     @mark_cuda_test
     def test_addition_array(self):
@@ -133,9 +145,19 @@ class TestGPUArray:
         a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
         a_gpu = gpuarray.to_gpu(a)
+        b = np.array(1).astype(np.float32)
+        b_gpu = gpuarray.to_gpu(b)
 
         a_added = (a_gpu + a_gpu).get()
+        a_added_scalar = (a_gpu + 1).get()
+        scalar_added_a = (1 + a_gpu).get()
+        a_gpu_pl_b_gpu = (a_gpu + b_gpu).get()
+        b_gpu_pl_a_gpu = (b_gpu + a_gpu).get()
 
         assert (a + a == a_added).all()
+        assert (a + 1 == a_added_scalar).all()
+        assert (1 + a == scalar_added_a).all()
+        assert (a + b == a_gpu_pl_b_gpu).all()
+        assert (b + a == b_gpu_pl_a_gpu).all()
 
     @mark_cuda_test
     def test_iaddition_array(self):
@@ -170,14 +192,16 @@ class TestGPUArray:
         assert (7 + a == a_added).all()
 
     @mark_cuda_test
-    def test_substract_array(self):
+    def test_subtract_array(self):
         """Test the subtraction of two arrays."""
         # test data
         a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
         b = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]).astype(np.float32)
+        c = np.array(1).astype(np.float32)
 
         a_gpu = gpuarray.to_gpu(a)
         b_gpu = gpuarray.to_gpu(b)
+        c_gpu = gpuarray.to_gpu(c)
 
         result = (a_gpu - b_gpu).get()
         assert (a - b == result).all()
@@ -185,8 +209,14 @@ class TestGPUArray:
         result = (b_gpu - a_gpu).get()
         assert (b - a == result).all()
 
+        result = (a_gpu - c_gpu).get()
+        assert (a - c == result).all()
+
+        result = (c_gpu - a_gpu).get()
+        assert (c - a == result).all()
+
     @mark_cuda_test
-    def test_substract_scalar(self):
+    def test_subtract_scalar(self):
         """Test the subtraction of an array and a scalar."""
 
         # test data
@@ -221,9 +251,11 @@ class TestGPUArray:
         # test data
         a = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]).astype(np.float32)
         b = np.array([10, 10, 10, 10, 10, 10, 10, 10, 10, 10]).astype(np.float32)
+        c = np.array(2)
 
         a_gpu = gpuarray.to_gpu(a)
         b_gpu = gpuarray.to_gpu(b)
+        c_gpu = gpuarray.to_gpu(c)
 
         a_divide = (a_gpu / b_gpu).get()
         assert (np.abs(a / b - a_divide) < 1e-3).all()
@@ -231,6 +263,12 @@ class TestGPUArray:
         a_divide = (b_gpu / a_gpu).get()
         assert (np.abs(b / a - a_divide) < 1e-3).all()
 
+        a_divide = (a_gpu / c_gpu).get()
+        assert (np.abs(a / c - a_divide) < 1e-3).all()
+
+        a_divide = (c_gpu / a_gpu).get()
+        assert (np.abs(c / a - a_divide) < 1e-3).all()
+
     @mark_cuda_test
     def test_random(self):
         from pycuda.curandom import rand as curand
@@ -321,85 +359,57 @@ class TestGPUArray:
         # # Compare with scipy.stats.poisson.pmf(v - 1, v)
         # assert np.isclose(0.12511, tmp, atol=0.002)
 
-    @mark_cuda_test
-    def test_array_gt(self):
-        """Test whether array contents are > the other array's
-        contents"""
-
-        a = np.array([5, 10]).astype(np.float32)
-        a_gpu = gpuarray.to_gpu(a)
-        b = np.array([2, 10]).astype(np.float32)
-        b_gpu = gpuarray.to_gpu(b)
-        result = (a_gpu > b_gpu).get()
-        assert result[0]
-        assert not result[1]
-
-    @mark_cuda_test
-    def test_array_lt(self):
-        """Test whether array contents are < the other array's
-        contents"""
-
-        a = np.array([5, 10]).astype(np.float32)
-        a_gpu = gpuarray.to_gpu(a)
-        b = np.array([2, 10]).astype(np.float32)
-        b_gpu = gpuarray.to_gpu(b)
-        result = (b_gpu < a_gpu).get()
-        assert result[0]
-        assert not result[1]
-
-    @mark_cuda_test
-    def test_array_le(self):
-        """Test whether array contents are <= the other array's
-        contents"""
-
-        a = np.array([5, 10, 1]).astype(np.float32)
-        a_gpu = gpuarray.to_gpu(a)
-        b = np.array([2, 10, 2]).astype(np.float32)
-        b_gpu = gpuarray.to_gpu(b)
-        result = (b_gpu <= a_gpu).get()
-        assert result[0]
-        assert result[1]
-        assert not result[2]
-
-    @mark_cuda_test
-    def test_array_ge(self):
-        """Test whether array contents are >= the other array's
-        contents"""
-
-        a = np.array([5, 10, 1]).astype(np.float32)
-        a_gpu = gpuarray.to_gpu(a)
-        b = np.array([2, 10, 2]).astype(np.float32)
-        b_gpu = gpuarray.to_gpu(b)
-        result = (a_gpu >= b_gpu).get()
-        assert result[0]
-        assert result[1]
-        assert not result[2]
-
-    @mark_cuda_test
-    def test_array_eq(self):
-        """Test whether array contents are == the other array's
-        contents"""
-
-        a = np.array([5, 10]).astype(np.float32)
-        a_gpu = gpuarray.to_gpu(a)
-        b = np.array([2, 10]).astype(np.float32)
-        b_gpu = gpuarray.to_gpu(b)
-        result = (a_gpu == b_gpu).get()
-        assert not result[0]
-        assert result[1]
-
-    @mark_cuda_test
-    def test_array_ne(self):
-        """Test whether array contents are != the other array's
-        contents"""
+    import operator
+
+    @pytest.mark.parametrize(("op_func", "dtype"), [
+        (operator.and_, np.int32),
+        (operator.or_, np.int32),
+        (operator.xor, np.int32),
+        (operator.lt, np.int32),
+        (operator.lt, np.float32),
+        (operator.le, np.int32),
+        (operator.le, np.float32),
+        (operator.eq, np.int32),
+        (operator.eq, np.float32),
+        (operator.ne, np.int32),
+        (operator.ne, np.float32),
+        (operator.ge, np.int32),
+        (operator.ge, np.float32),
+        (operator.gt, np.int32),
+        (operator.gt, np.float32)])
+    def test_array_bitwise_operations(self, op_func, dtype):
+        """Test an operator and its reflected counterpart
+        on arrays, scalars and zero-dimensional arrays.
+        """
 
-        a = np.array([5, 10]).astype(np.float32)
+        a = np.array([5, 10]).astype(dtype)
         a_gpu = gpuarray.to_gpu(a)
-        b = np.array([2, 10]).astype(np.float32)
+        b = np.array([2, 10]).astype(dtype)
         b_gpu = gpuarray.to_gpu(b)
-        result = (a_gpu != b_gpu).get()
-        assert result[0]
-        assert not result[1]
+        scalar = np.int32(10)
+        c = np.array(5).astype(dtype)
+        c_gpu = gpuarray.to_gpu(c)
+
+        result = op_func(a_gpu, b_gpu).get()
+        result_ref = op_func(a, b)
+        np.testing.assert_allclose(result, result_ref, rtol=1e-5)
+        result = op_func(b_gpu, a_gpu).get()
+        result_ref = op_func(b, a)
+        np.testing.assert_allclose(result, result_ref, rtol=1e-5)
+
+        result = op_func(a_gpu, scalar).get()
+        result_ref = op_func(a, scalar)
+        np.testing.assert_allclose(result, result_ref, rtol=1e-5)
+        result = op_func(scalar, a_gpu).get()
+        result_ref = op_func(scalar, a)
+        np.testing.assert_allclose(result, result_ref, rtol=1e-5)
+
+        result = op_func(a_gpu, c_gpu).get()
+        result_ref = op_func(a, c)
+        np.testing.assert_allclose(result, result_ref, rtol=1e-5)
+        result = op_func(c_gpu, a_gpu).get()
+        result_ref = op_func(c, a)
+        np.testing.assert_allclose(result, result_ref, rtol=1e-5)
 
     @mark_cuda_test
     def test_nan_arithmetic(self):
@@ -1297,6 +1307,19 @@ class TestGPUArray:
         assert new_z.dtype == np.complex64
         assert new_z.shape == arr.shape
 
+    @pytest.mark.parametrize(("op_func", "dtype"), [
+        # float only
+        (operator.and_, np.float32),
+        (operator.and_, np.float64),
+        (operator.or_, np.float32),
+        (operator.or_, np.float64),
+        (operator.xor, np.float32),
+        (operator.xor, np.float64)])
+    def test_bitwise_ops_raise_on_float_inputs(self, op_func, dtype):
+        a = gpuarray.to_gpu(np.random.rand(10).astype(dtype))
+        with pytest.raises(TypeError):
+            op_func(a, a)
+
 
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
-- 
GitLab
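
A minimal usage sketch of the behaviour the patch adds, assuming PyCUDA is built with the patch applied, a CUDA-capable device is present, and pycuda.autoinit can create a context; the array values are illustrative only.

# Minimal usage sketch, assuming this patch is applied and a CUDA device
# plus context are available (pycuda.autoinit creates the context).
import numpy as np

import pycuda.autoinit  # noqa: F401
import pycuda.gpuarray as gpuarray

a = gpuarray.to_gpu(np.arange(1, 11, dtype=np.float32))
s = gpuarray.to_gpu(np.array(2, dtype=np.float32))   # 0-d "scalar" gpuarray

# 0-d gpuarrays now broadcast against ordinary gpuarrays for +, -, *, /.
print((a + s).get())   # elementwise a + 2
print((s - a).get())   # reversed operand order broadcasts as well

# Bitwise operators are defined for integer arrays and integer scalars ...
x = gpuarray.to_gpu(np.array([5, 10], dtype=np.int32))
y = gpuarray.to_gpu(np.array([2, 10], dtype=np.int32))
print((x & y).get(), (x | np.int32(3)).get(), (x ^ y).get())

# ... and raise TypeError when either operand is floating point.
try:
    a & a
except TypeError as err:
    print("rejected float operands:", err)

The comparison operators (<, <=, ==, !=, >=, >) reuse the same scalar and reverse-scalar kernels, which is what the parametrized test above exercises.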