Skip to content
Snippets Groups Projects
Commit cde5b8b7 authored by Isuru Fernando's avatar Isuru Fernando Committed by Andreas Klöckner
Browse files

rename to muladd and add a test

parent 466d265f
No related branches found
No related tags found
No related merge requests found
......@@ -77,7 +77,7 @@
; \
} \
\
inline TP TPROOT##_fma(TP c, TP a, TP b) \
inline TP TPROOT##_muladd(TP a, TP b, TP c) \
{ \
return TPROOT##_new( \
(c.real + a.real*b.real) - a.imag*b.imag, \
......
......@@ -447,6 +447,68 @@ def test_hankel_01_complex(ctx_factory, ref_src):
pt.show()
@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_complex_muladd(ctx_factory, dtype):
    """Compare the device-side ``c<real>_muladd(a, b, c)`` with ``a*b + c``.

    A small kernel including ``pyopencl-complex.h`` is built for the requested
    complex *dtype*, run on three random complex arrays, and its output is
    checked against the same expression evaluated with numpy.

    :arg ctx_factory: callable returning the OpenCL context to test on.
    :arg dtype: ``np.complex64`` or ``np.complex128``; the latter is skipped
        when the first device of the context lacks double support.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    is_double = dtype == np.complex128
    if is_double and not has_double_support(ctx.devices[0]):
        pytest.skip("no double precision support")

    if is_double:
        real_type = np.float64
        real_type_name = "double"
        max_rel_err = 1e-12
    else:
        real_type = np.float32
        real_type_name = "float"
        max_rel_err = 1e-6

    rng = np.random.default_rng(seed=11)
    n = 100

    # Real and imaginary parts are drawn separately in the matching real
    # precision, then the sum is cast to the complex dtype under test.
    arrs = [
        (rng.random(n, dtype=real_type)
            + 1j*rng.random(n, dtype=real_type)).astype(dtype)
        for _ in range(3)]
    arrs_dev = [cl_array.to_device(queue, arr) for arr in arrs]

    # Only ask for fp64 and the cdouble type when double precision is actually
    # being tested, so that the single-precision variant does not reference
    # 'double' in the kernel source on devices without double support.
    if is_double:
        fp64_preamble = """
        #if __OPENCL_C_VERSION__ < 120
        #pragma OPENCL EXTENSION cl_khr_fp64: enable
        #endif
        #define PYOPENCL_DEFINE_CDOUBLE
        """
    else:
        fp64_preamble = ""

    prg_str = """
        {fp64_preamble}
        #include <pyopencl-complex.h>
        __kernel void foo(
            __global const c{real_type_name}_t *a,
            __global const c{real_type_name}_t *b,
            __global const c{real_type_name}_t *c,
            __global c{real_type_name}_t *res
        )
        {{
            int gid = get_global_id(0);
            res[gid] = c{real_type_name}_muladd(a[gid], b[gid], c[gid]);
        }}
        """.format(fp64_preamble=fp64_preamble, real_type_name=real_type_name)

    prg = cl.Program(ctx, prg_str).build()
    knl = prg.foo

    result_dev = cl_array.empty_like(arrs_dev[0])
    knl(queue, (n,), None,
            arrs_dev[0].data, arrs_dev[1].data, arrs_dev[2].data,
            result_dev.data)

    # Host-side reference for muladd(a, b, c) = a*b + c.
    ref = arrs[0] * arrs[1] + arrs[2]
    rel_err = np.abs(result_dev.get() - ref)/np.abs(ref)

    assert np.max(rel_err) < max_rel_err
def test_outoforderqueue_clmath(ctx_factory):
context = ctx_factory()
try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment