diff --git a/pyopencl/clmath.py b/pyopencl/clmath.py index 128985297721b3cf1a9ba35effcb5c4010dd7053..1b41ce67d0ce9b7543c6f76944edf29c1a082ff0 100644 --- a/pyopencl/clmath.py +++ b/pyopencl/clmath.py @@ -57,14 +57,14 @@ asinpi = _make_unary_array_func("asinpi") @cl_array.elwise_kernel_runner def _atan2(result, arg1, arg2): - return elementwise.get_binary_func_kernel(result.context, "atan2", - arg1.dtype, arg2.dtype, result.dtype) + return elementwise.get_float_binary_func_kernel( + result.context, "atan2", arg1.dtype, arg2.dtype, result.dtype) @cl_array.elwise_kernel_runner def _atan2pi(result, arg1, arg2): - return elementwise.get_binary_func_kernel(result.context, "atan2pi", - arg1.dtype, arg2.dtype, result.dtype) + return elementwise.get_float_binary_func_kernel( + result.context, "atan2pi", arg1.dtype, arg2.dtype, result.dtype) atan = _make_unary_array_func("atan") diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 8d5836815d6610e29c824880424ef485c5eff47b..03c37024489a57b70860f71f8d0a99fd1edf7159 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -837,11 +837,32 @@ def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype, preamble=preamble) +@context_dependent_memoize +def get_float_binary_func_kernel(context, func_name, x_dtype, y_dtype, + out_dtype, preamble="", name=None): + if (np.array(0, x_dtype) * np.array(0, y_dtype)).itemsize > 4: + arg_type = 'double' + preamble = """ + #pragma OPENCL EXTENSION cl_khr_fp64: enable + #define PYOPENCL_DEFINE_CDOUBLE + """ + preamble + else: + arg_type = 'float' + return get_elwise_kernel(context, [ + VectorArg(out_dtype, "z", with_offset=True), + VectorArg(x_dtype, "x", with_offset=True), + VectorArg(y_dtype, "y", with_offset=True), + ], + "z[i] = %s((%s)x[i], (%s)y[i])" % (func_name, arg_type, arg_type), + name="%s_kernel" % func_name if name is None else name, + preamble=preamble) + + @context_dependent_memoize def get_fmod_kernel(context, out_dtype=np.float32, arg_dtype=np.float32, mod_dtype=np.float32): - return get_binary_func_kernel(context, 'fmod', arg_dtype, - mod_dtype, out_dtype) + return get_float_binary_func_kernel(context, 'fmod', arg_dtype, + mod_dtype, out_dtype) @context_dependent_memoize diff --git a/test/test_clmath.py b/test/test_clmath.py index 3091e94273dff7040fb1c4446780420802278342..6ebbe46bcb1f49ed70e1797c003013fe43d79147 100644 --- a/test/test_clmath.py +++ b/test/test_clmath.py @@ -127,6 +127,40 @@ if have_cl(): test_tanh = make_unary_function_test("tanh", (-3, 3), 2e-6, use_complex=True) +def test_atan2(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + for s in sizes: + a = (cl_array.arange(queue, s, dtype=np.float32) - s / 2) / 100 + a2 = (s / 2 - 1 - cl_array.arange(queue, s, dtype=np.float32)) / 100 + b = clmath.atan2(a, a2) + + a = a.get() + a2 = a2.get() + b = b.get() + + for i in range(s): + assert abs(math.atan2(a[i], a2[i]) - b[i]) < 1e-6 + + +def test_atan2pi(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + for s in sizes: + a = (cl_array.arange(queue, s, dtype=np.float32) - s / 2) / 100 + a2 = (s / 2 - 1 - cl_array.arange(queue, s, dtype=np.float32)) / 100 + b = clmath.atan2pi(a, a2) + + a = a.get() + a2 = a2.get() + b = b.get() + + for i in range(s): + assert abs(math.atan2(a[i], a2[i]) / math.pi - b[i]) < 1e-6 + + def test_fmod(ctx_factory): context = ctx_factory() queue = cl.CommandQueue(context)