From de6327bb21b5f0361fb606a9d7932d7c57187ef9 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992@gmail.com>
Date: Sun, 18 May 2014 10:20:57 -0400
Subject: [PATCH] handle different input types in clmath

---
 pyopencl/clmath.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/pyopencl/clmath.py b/pyopencl/clmath.py
index d3c6d0f7..12898529 100644
--- a/pyopencl/clmath.py
+++ b/pyopencl/clmath.py
@@ -23,6 +23,7 @@ THE SOFTWARE.
 import pyopencl.array as cl_array
 import pyopencl.elementwise as elementwise
 from pyopencl.array import _get_common_dtype
+import numpy as np
 
 
 def _make_unary_array_func(name):
@@ -118,13 +119,15 @@ floor = _make_unary_array_func("floor")
 
 @cl_array.elwise_kernel_runner
 def _fmod(result, arg, mod):
-    return elementwise.get_fmod_kernel(result.context)
+    return elementwise.get_fmod_kernel(result.context, result.dtype,
+                                       arg.dtype, mod.dtype)
 
 
 def fmod(arg, mod, queue=None):
     """Return the floating point remainder of the division `arg/mod`,
     for each element in `arg` and `mod`."""
-    result = arg._new_like_me(queue=queue)
+    queue = (queue or arg.queue) or mod.queue
+    result = arg._new_like_me(_get_common_dtype(arg, mod, queue))
     _fmod(result, arg, mod, queue=queue)
     return result
 
@@ -133,7 +136,8 @@ def fmod(arg, mod, queue=None):
 
 @cl_array.elwise_kernel_runner
 def _frexp(sig, expt, arg):
-    return elementwise.get_frexp_kernel(sig.context)
+    return elementwise.get_frexp_kernel(sig.context, sig.dtype,
+                                        expt.dtype, arg.dtype)
 
 
 def frexp(arg, queue=None):
@@ -153,7 +157,8 @@ ilogb = _make_unary_array_func("ilogb")
 
 @cl_array.elwise_kernel_runner
 def _ldexp(result, sig, exp):
-    return elementwise.get_ldexp_kernel(result.context)
+    return elementwise.get_ldexp_kernel(result.context, result.dtype,
+                                        sig.dtype, exp.dtype)
 
 
 def ldexp(significand, exponent, queue=None):
@@ -181,7 +186,8 @@ logb = _make_unary_array_func("logb")
 
 @cl_array.elwise_kernel_runner
 def _modf(intpart, fracpart, arg):
-    return elementwise.get_modf_kernel(intpart.context)
+    return elementwise.get_modf_kernel(intpart.context, intpart.dtype,
+                                       fracpart.dtype, arg.dtype)
 
 
 def modf(arg, queue=None):
@@ -223,13 +229,15 @@ trunc = _make_unary_array_func("trunc")
 # TODO: table 6.12, clamp et al
 
 @cl_array.elwise_kernel_runner
-def _bessel_jn(result, sig, exp):
-    return elementwise.get_bessel_kernel(result.context, "j")
+def _bessel_jn(result, n, x):
+    return elementwise.get_bessel_kernel(result.context, "j", result.dtype,
+                                         np.dtype(type(n)), x.dtype)
 
 
 @cl_array.elwise_kernel_runner
-def _bessel_yn(result, sig, exp):
-    return elementwise.get_bessel_kernel(result.context, "y")
+def _bessel_yn(result, n, x):
+    return elementwise.get_bessel_kernel(result.context, "y", result.dtype,
+                                         np.dtype(type(n)), x.dtype)
 
 
 def bessel_jn(n, x, queue=None):
-- 
GitLab