From de6327bb21b5f0361fb606a9d7932d7c57187ef9 Mon Sep 17 00:00:00 2001 From: Yichao Yu <yyc1992@gmail.com> Date: Sun, 18 May 2014 10:20:57 -0400 Subject: [PATCH] handle different input types in clmath --- pyopencl/clmath.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/pyopencl/clmath.py b/pyopencl/clmath.py index d3c6d0f7..12898529 100644 --- a/pyopencl/clmath.py +++ b/pyopencl/clmath.py @@ -23,6 +23,7 @@ THE SOFTWARE. import pyopencl.array as cl_array import pyopencl.elementwise as elementwise from pyopencl.array import _get_common_dtype +import numpy as np def _make_unary_array_func(name): @@ -118,13 +119,15 @@ floor = _make_unary_array_func("floor") @cl_array.elwise_kernel_runner def _fmod(result, arg, mod): - return elementwise.get_fmod_kernel(result.context) + return elementwise.get_fmod_kernel(result.context, result.dtype, + arg.dtype, mod.dtype) def fmod(arg, mod, queue=None): """Return the floating point remainder of the division `arg/mod`, for each element in `arg` and `mod`.""" - result = arg._new_like_me(queue=queue) + queue = (queue or arg.queue) or mod.queue + result = arg._new_like_me(_get_common_dtype(arg, mod, queue)) _fmod(result, arg, mod, queue=queue) return result @@ -133,7 +136,8 @@ def fmod(arg, mod, queue=None): @cl_array.elwise_kernel_runner def _frexp(sig, expt, arg): - return elementwise.get_frexp_kernel(sig.context) + return elementwise.get_frexp_kernel(sig.context, sig.dtype, + expt.dtype, arg.dtype) def frexp(arg, queue=None): @@ -153,7 +157,8 @@ ilogb = _make_unary_array_func("ilogb") @cl_array.elwise_kernel_runner def _ldexp(result, sig, exp): - return elementwise.get_ldexp_kernel(result.context) + return elementwise.get_ldexp_kernel(result.context, result.dtype, + sig.dtype, exp.dtype) def ldexp(significand, exponent, queue=None): @@ -181,7 +186,8 @@ logb = _make_unary_array_func("logb") @cl_array.elwise_kernel_runner def _modf(intpart, fracpart, arg): - return elementwise.get_modf_kernel(intpart.context) + return elementwise.get_modf_kernel(intpart.context, intpart.dtype, + fracpart.dtype, arg.dtype) def modf(arg, queue=None): @@ -223,13 +229,15 @@ trunc = _make_unary_array_func("trunc") # TODO: table 6.12, clamp et al @cl_array.elwise_kernel_runner -def _bessel_jn(result, sig, exp): - return elementwise.get_bessel_kernel(result.context, "j") +def _bessel_jn(result, n, x): + return elementwise.get_bessel_kernel(result.context, "j", result.dtype, + np.dtype(type(n)), x.dtype) @cl_array.elwise_kernel_runner -def _bessel_yn(result, sig, exp): - return elementwise.get_bessel_kernel(result.context, "y") +def _bessel_yn(result, n, x): + return elementwise.get_bessel_kernel(result.context, "y", result.dtype, + np.dtype(type(n)), x.dtype) def bessel_jn(n, x, queue=None): -- GitLab