From 7f6649cda6d974187aabd7ee616e168ffa9ce274 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 10 Feb 2022 16:28:07 -0600 Subject: [PATCH] Fix, test integer ary/ary division to match numpy --- pyopencl/elementwise.py | 5 +++++ test/test_array.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 9f51c299..87750403 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -645,6 +645,11 @@ def get_divide_kernel(context, dtype_x, dtype_y, dtype_z, x = "{}_cast({})".format(complex_dtype_to_name(dtype_z), x) if y_is_complex and dtype_y != dtype_z: y = "{}_cast({})".format(complex_dtype_to_name(dtype_z), y) + else: + if dtype_x != dtype_z: + x = f"({dtype_to_ctype(dtype_z)}) ({x})" + if dtype_y != dtype_z: + y = f"({dtype_to_ctype(dtype_z)}) ({y})" if x_is_complex and y_is_complex: xoy = "{}_divide({}, {})".format(complex_dtype_to_name(dtype_z), x, y) diff --git a/test/test_array.py b/test/test_array.py index e2834179..45fdaa27 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -40,6 +40,7 @@ from pyopencl.tools import ( # noqa from pyopencl.characterize import has_double_support, has_struct_arg_count_bug from pyopencl.clrandom import RanluxGenerator, PhiloxGenerator, ThreefryGenerator +import operator _PYPY = cl._PYPY @@ -343,6 +344,22 @@ def test_custom_type_take_put(ctx_factory): # {{{ operators +@pytest.mark.parametrize("dtype", [np.int8, np.int32, np.int64, np.float32]) +# FIXME Implement florodiv +#@pytest.mark.parametrize("op", [operator.truediv, operator.floordiv]) +@pytest.mark.parametrize("op", [operator.truediv]) +def test_div_type_matches_numpy(ctx_factory, dtype, op): + context = ctx_factory() + queue = cl.CommandQueue(context) + + a = cl_array.arange(queue, 10, dtype=dtype) + 1 + res = op(4*a, 3*a) + a_np = a.get() + res_np = op(4*a_np, 3*a_np) + assert res_np.dtype == res.dtype + assert np.allclose(res_np, res.get()) + + def test_rmul_yields_right_type(ctx_factory): context = ctx_factory() queue = cl.CommandQueue(context) -- GitLab