diff --git a/pyopencl/cl/pyopencl-complex.h b/pyopencl/cl/pyopencl-complex.h index beee8f6c38cdb4667fca632a912c82a52b784fbd..c33cbca0d334610956ab8cf99e6f1b869c656739 100644 --- a/pyopencl/cl/pyopencl-complex.h +++ b/pyopencl/cl/pyopencl-complex.h @@ -77,7 +77,7 @@ ; \ } \ \ - inline TP TPROOT##_fma(TP c, TP a, TP b) \ + inline TP TPROOT##_muladd(TP a, TP b, TP c) \ { \ return TPROOT##_new( \ (c.real + a.real*b.real) - a.imag*b.imag, \ diff --git a/test/test_clmath.py b/test/test_clmath.py index e5032dc148a731f0acfb63497f7a00d8656253a7..b640e0bbf979c9ed84139184ad5165513db48291 100644 --- a/test/test_clmath.py +++ b/test/test_clmath.py @@ -447,6 +447,68 @@ def test_hankel_01_complex(ctx_factory, ref_src): pt.show() +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) +def test_complex_muladd(ctx_factory, dtype): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + if dtype == np.complex128 and not has_double_support(ctx.devices[0]): + from pytest import skip + skip("no double precision support") + + if dtype == np.complex64: + real_type = np.float32 + real_type_name = "float" + else: + real_type = np.float64 + real_type_name = "double" + + rng = np.random.default_rng(seed=11) + n = 100 + + arrs = [rng.random(n, dtype=real_type) + 1j*rng.random(n, dtype=real_type) + for i in range(3)] + arrs = [arr.astype(dtype) for arr in arrs] + + arrs_dev = [cl_array.to_device(queue, arr) for arr in arrs] + + prg_str = """ + #if __OPENCL_C_VERSION__ < 120 + #pragma OPENCL EXTENSION cl_khr_fp64: enable + #endif + #define PYOPENCL_DEFINE_CDOUBLE + + #include + + __kernel void foo( + __global const c{real_type_name}_t *a, + __global const c{real_type_name}_t *b, + __global const c{real_type_name}_t *c, + __global c{real_type_name}_t *res + ) + {{ + int gid = get_global_id(0); + res[gid] = c{real_type_name}_muladd(a[gid], b[gid], c[gid]); + }} + """.format(real_type_name=real_type_name) + + prg = cl.Program(ctx, prg_str).build() + knl = prg.foo + + result_dev = cl_array.empty_like(arrs_dev[0]) + knl(queue, (n,), None, arrs_dev[0].data, + arrs_dev[1].data, arrs_dev[2].data, result_dev.data) + + ref = arrs[0] * arrs[1] + arrs[2] + + rel_err = np.abs(result_dev.get() - ref)/np.abs(ref) + + if dtype == np.complex64: + assert np.max(rel_err) < 1e-6 + else: + assert np.max(rel_err) < 1e-12 + + def test_outoforderqueue_clmath(ctx_factory): context = ctx_factory() try: