From 56a2ea5607f3f10f14466cd34fb8290947bd17d4 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 23 Jan 2012 01:25:07 -0500
Subject: [PATCH] Make complex division more SIMD-friendly.

---
 src/cl/pyopencl-complex.h | 36 +++++++++++++++++++-----------------
 test/test_array.py        |  4 +++-
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/src/cl/pyopencl-complex.h b/src/cl/pyopencl-complex.h
index 26bfbde3..8e4093f2 100644
--- a/src/cl/pyopencl-complex.h
+++ b/src/cl/pyopencl-complex.h
@@ -45,10 +45,7 @@
   \
   TP TPROOT##_rdivide(REAL_TP z1, TP z2) \
   { \
-    REAL_TP ar = z2.x >= 0 ? z2.x : -z2.x; \
-    REAL_TP ai = z2.y >= 0 ? z2.y : -z2.y; \
-    \
-    if (ar <= ai) { \
+    if (fabs(z2.x) <= fabs(z2.y)) { \
       REAL_TP ratio = z2.x / z2.y; \
       REAL_TP denom = z2.y * (1 + ratio * ratio); \
       return (TP)((z1 * ratio) / denom, - z1 / denom); \
@@ -62,23 +59,28 @@
   \
   TP TPROOT##_divide(TP z1, TP z2) \
   { \
-    REAL_TP ar = z2.x >= 0 ? z2.x : -z2.x; \
-    REAL_TP ai = z2.y >= 0 ? z2.y : -z2.y; \
+    REAL_TP ratio, denom, a, b, c, d; \
     \
-    if (ar <= ai) { \
-      REAL_TP ratio = z2.x / z2.y; \
-      REAL_TP denom = z2.y * (1 + ratio * ratio); \
-      return (TP)( \
-         (z1.x * ratio + z1.y) / denom, \
-         (z1.y * ratio - z1.x) / denom); \
+    if (fabs(z2.x) <= fabs(z2.y)) { \
+      ratio = z2.x / z2.y; \
+      denom = z2.y; \
+      a = z1.y; \
+      b = z1.x; \
+      c = -z1.x; \
+      d = z1.y; \
     } \
     else { \
-      REAL_TP ratio = z2.y / z2.x; \
-      REAL_TP denom = z2.x * (1 + ratio * ratio); \
-      return (TP)( \
-         (z1.x + z1.y * ratio) / denom, \
-         (z1.y - z1.x * ratio) / denom); \
+      ratio = z2.y / z2.x; \
+      denom = z2.x; \
+      a = z1.x; \
+      b = z1.y; \
+      c = z1.y; \
+      d = -z1.x; \
     } \
+    denom *= (1 + ratio * ratio); \
+    return (TP)( \
+       (a + b * ratio) / denom, \
+       (c + d * ratio) / denom); \
   } \
   \
   TP TPROOT##_pow(TP a, TP b) \
diff --git a/test/test_array.py b/test/test_array.py
index 0f6ffebd..78ea5535 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -166,7 +166,9 @@ def test_mix_complex(ctx_factory):
 
                         dev_result = dev_result.astype(host_result.dtype)
 
-                    correct = np.allclose(host_result, dev_result)
+                    err = la.norm(host_result-dev_result)/la.norm(host_result)
+                    print err
+                    correct = err < 1e-5
                     if not correct:
                         print host_result
                         print dev_result
-- 
GitLab