diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 9f51c29934777518cc6221d8b29d377986191a9c..877504038a3177b70068a9ed0a54540d6d7b80c1 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -645,6 +645,11 @@ def get_divide_kernel(context, dtype_x, dtype_y, dtype_z, x = "{}_cast({})".format(complex_dtype_to_name(dtype_z), x) if y_is_complex and dtype_y != dtype_z: y = "{}_cast({})".format(complex_dtype_to_name(dtype_z), y) + else: + if dtype_x != dtype_z: + x = f"({dtype_to_ctype(dtype_z)}) ({x})" + if dtype_y != dtype_z: + y = f"({dtype_to_ctype(dtype_z)}) ({y})" if x_is_complex and y_is_complex: xoy = "{}_divide({}, {})".format(complex_dtype_to_name(dtype_z), x, y) diff --git a/test/test_array.py b/test/test_array.py index e283417996d118ec950cfff2d06f311530f57a66..45fdaa27cb8fd5a0f3ebc8a958b4dc6323d1c6af 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -40,6 +40,7 @@ from pyopencl.tools import ( # noqa from pyopencl.characterize import has_double_support, has_struct_arg_count_bug from pyopencl.clrandom import RanluxGenerator, PhiloxGenerator, ThreefryGenerator +import operator _PYPY = cl._PYPY @@ -343,6 +344,22 @@ def test_custom_type_take_put(ctx_factory): # {{{ operators +@pytest.mark.parametrize("dtype", [np.int8, np.int32, np.int64, np.float32]) +# FIXME Implement florodiv +#@pytest.mark.parametrize("op", [operator.truediv, operator.floordiv]) +@pytest.mark.parametrize("op", [operator.truediv]) +def test_div_type_matches_numpy(ctx_factory, dtype, op): + context = ctx_factory() + queue = cl.CommandQueue(context) + + a = cl_array.arange(queue, 10, dtype=dtype) + 1 + res = op(4*a, 3*a) + a_np = a.get() + res_np = op(4*a_np, 3*a_np) + assert res_np.dtype == res.dtype + assert np.allclose(res_np, res.get()) + + def test_rmul_yields_right_type(ctx_factory): context = ctx_factory() queue = cl.CommandQueue(context)