From 670c7790881709624ba4b5604ed45cce3ec95494 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 9 Sep 2020 16:56:43 -0500 Subject: [PATCH] Test, correctly handle int64 - uint32 (closes gh-355) --- pyopencl/array.py | 4 ++-- pyopencl/elementwise.py | 17 +++++++++-------- test/test_array.py | 14 +++++++++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index b06dd706..8d5cad90 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1008,7 +1008,7 @@ class Array: result.add_event( self._axpbyz(result, self.dtype.type(1), self, - other.dtype.type(-1), other)) + result.dtype.type(-1), other)) return result else: @@ -1031,7 +1031,7 @@ class Array: # other must be a scalar result = self._new_like_me(common_dtype) result.add_event( - self._axpbz(result, self.dtype.type(-1), self, + self._axpbz(result, result.dtype.type(-1), self, common_dtype.type(other))) return result diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 357aa2bb..9e6c762a 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -503,23 +503,24 @@ def real_dtype(dtype): @context_dependent_memoize def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z): - ax = "a*x[i]" - by = "b*y[i]" + result_t = dtype_to_ctype(dtype_z) x_is_complex = dtype_x.kind == "c" y_is_complex = dtype_y.kind == "c" if x_is_complex: ax = "%s_mul(a, x[i])" % complex_dtype_to_name(dtype_x) + elif not x_is_complex and y_is_complex: + ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax) + else: + ax = f"a*(({result_t}) x[i])" if y_is_complex: by = "%s_mul(b, y[i])" % complex_dtype_to_name(dtype_y) - - if x_is_complex and not y_is_complex: + elif x_is_complex and not y_is_complex: by = "{}_fromreal({})".format(complex_dtype_to_name(dtype_x), by) - - if not x_is_complex and y_is_complex: - ax = "{}_fromreal({})".format(complex_dtype_to_name(dtype_y), ax) + else: + by = f"b*(({result_t}) y[i])" if x_is_complex or y_is_complex: result = ( @@ -532,7 +533,7 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z): result = f"{ax} + {by}" return get_elwise_kernel(context, - "{tp_z} *z, {tp_x} a, {tp_x} *x, {tp_y} b, {tp_y} *y".format( + "{tp_z} *z, {tp_z} a, {tp_x} *x, {tp_z} b, {tp_y} *y".format( tp_x=dtype_to_ctype(dtype_x), tp_y=dtype_to_ctype(dtype_y), tp_z=dtype_to_ctype(dtype_z), diff --git a/test/test_array.py b/test/test_array.py index 39f8fd74..d1777237 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -426,12 +426,20 @@ def test_addition_scalar(ctx_factory): assert (7 + a == a_added).all() -def test_substract_array(ctx_factory): +@pytest.mark.parametrize(("dtype_a", "dtype_b"), + [ + (np.float32, np.float32), + (np.float32, np.int32), + (np.int32, np.int32), + (np.int64, np.int32), + (np.int64, np.uint32), + ]) +def test_subtract_array(ctx_factory, dtype_a, dtype_b): """Test the substraction of two arrays.""" #test data - a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32) + a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(dtype_a) b = np.array([10, 20, 30, 40, 50, - 60, 70, 80, 90, 100]).astype(np.float32) + 60, 70, 80, 90, 100]).astype(dtype_b) context = ctx_factory() queue = cl.CommandQueue(context) -- GitLab