From 56a2ea5607f3f10f14466cd34fb8290947bd17d4 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 23 Jan 2012 01:25:07 -0500 Subject: [PATCH] Make complex division more SIMD-friendly. --- src/cl/pyopencl-complex.h | 36 +++++++++++++++++++----------------- test/test_array.py | 4 +++- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/cl/pyopencl-complex.h b/src/cl/pyopencl-complex.h index 26bfbde3..8e4093f2 100644 --- a/src/cl/pyopencl-complex.h +++ b/src/cl/pyopencl-complex.h @@ -45,10 +45,7 @@ \ TP TPROOT##_rdivide(REAL_TP z1, TP z2) \ { \ - REAL_TP ar = z2.x >= 0 ? z2.x : -z2.x; \ - REAL_TP ai = z2.y >= 0 ? z2.y : -z2.y; \ - \ - if (ar <= ai) { \ + if (fabs(z2.x) <= fabs(z2.y)) { \ REAL_TP ratio = z2.x / z2.y; \ REAL_TP denom = z2.y * (1 + ratio * ratio); \ return (TP)((z1 * ratio) / denom, - z1 / denom); \ @@ -62,23 +59,28 @@ \ TP TPROOT##_divide(TP z1, TP z2) \ { \ - REAL_TP ar = z2.x >= 0 ? z2.x : -z2.x; \ - REAL_TP ai = z2.y >= 0 ? z2.y : -z2.y; \ + REAL_TP ratio, denom, a, b, c, d; \ \ - if (ar <= ai) { \ - REAL_TP ratio = z2.x / z2.y; \ - REAL_TP denom = z2.y * (1 + ratio * ratio); \ - return (TP)( \ - (z1.x * ratio + z1.y) / denom, \ - (z1.y * ratio - z1.x) / denom); \ + if (fabs(z2.x) <= fabs(z2.y)) { \ + ratio = z2.x / z2.y; \ + denom = z2.y; \ + a = z1.y; \ + b = z1.x; \ + c = -z1.x; \ + d = z1.y; \ } \ else { \ - REAL_TP ratio = z2.y / z2.x; \ - REAL_TP denom = z2.x * (1 + ratio * ratio); \ - return (TP)( \ - (z1.x + z1.y * ratio) / denom, \ - (z1.y - z1.x * ratio) / denom); \ + ratio = z2.y / z2.x; \ + denom = z2.x; \ + a = z1.x; \ + b = z1.y; \ + c = z1.y; \ + d = -z1.x; \ } \ + denom *= (1 + ratio * ratio); \ + return (TP)( \ + (a + b * ratio) / denom, \ + (c + d * ratio) / denom); \ } \ \ TP TPROOT##_pow(TP a, TP b) \ diff --git a/test/test_array.py b/test/test_array.py index 0f6ffebd..78ea5535 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -166,7 +166,9 @@ def test_mix_complex(ctx_factory): dev_result = dev_result.astype(host_result.dtype) - correct = np.allclose(host_result, dev_result) + err = la.norm(host_result-dev_result)/la.norm(host_result) + print err + correct = err < 1e-5 if not correct: print host_result print dev_result -- GitLab