From 7f6649cda6d974187aabd7ee616e168ffa9ce274 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 10 Feb 2022 16:28:07 -0600
Subject: [PATCH] Fix, test integer ary/ary division to match numpy

---
 pyopencl/elementwise.py |  5 +++++
 test/test_array.py      | 17 +++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 9f51c299..87750403 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -645,6 +645,11 @@ def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
             x = "{}_cast({})".format(complex_dtype_to_name(dtype_z), x)
         if y_is_complex and dtype_y != dtype_z:
             y = "{}_cast({})".format(complex_dtype_to_name(dtype_z), y)
+    else:
+        if dtype_x != dtype_z:
+            x = f"({dtype_to_ctype(dtype_z)}) ({x})"
+        if dtype_y != dtype_z:
+            y = f"({dtype_to_ctype(dtype_z)}) ({y})"
 
     if x_is_complex and y_is_complex:
         xoy = "{}_divide({}, {})".format(complex_dtype_to_name(dtype_z), x, y)
diff --git a/test/test_array.py b/test/test_array.py
index e2834179..45fdaa27 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -40,6 +40,7 @@ from pyopencl.tools import (  # noqa
 from pyopencl.characterize import has_double_support, has_struct_arg_count_bug
 
 from pyopencl.clrandom import RanluxGenerator, PhiloxGenerator, ThreefryGenerator
+import operator
 
 _PYPY = cl._PYPY
 
@@ -343,6 +344,22 @@ def test_custom_type_take_put(ctx_factory):
 
 # {{{ operators
 
+@pytest.mark.parametrize("dtype", [np.int8, np.int32, np.int64, np.float32])
+# FIXME Implement florodiv
+#@pytest.mark.parametrize("op", [operator.truediv, operator.floordiv])
+@pytest.mark.parametrize("op", [operator.truediv])
+def test_div_type_matches_numpy(ctx_factory, dtype, op):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    a = cl_array.arange(queue, 10, dtype=dtype) + 1
+    res = op(4*a, 3*a)
+    a_np = a.get()
+    res_np = op(4*a_np, 3*a_np)
+    assert res_np.dtype == res.dtype
+    assert np.allclose(res_np, res.get())
+
+
 def test_rmul_yields_right_type(ctx_factory):
     context = ctx_factory()
     queue = cl.CommandQueue(context)
-- 
GitLab