From a8d91e12d41b8ca07a2eb108a35f0ffd4c3f0f9d Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 2 May 2012 17:32:01 -0400
Subject: [PATCH] Add code to test against Hellskitchen bessel functions.

---
 test/test_clmath.py | 46 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/test/test_clmath.py b/test/test_clmath.py
index 38641380..72a6695a 100644
--- a/test/test_clmath.py
+++ b/test/test_clmath.py
@@ -187,7 +187,6 @@ def test_bessel_j(ctx_factory):
         from py.test import skip
         skip("scipy not present--cannot test Bessel function")
 
-    a = np.logspace(-5, 5, 10**6)
 
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
@@ -196,15 +195,47 @@ def test_bessel_j(ctx_factory):
         from py.test import skip
         skip("no double precision support--cannot test bessel function")
 
+    nterms = 30
+
+    try:
+        from hellskitchen._native import jfuns2d
+    except ImportError:
+        use_hellskitchen = False
+    else:
+        use_hellskitchen = True
+
+    if use_hellskitchen:
+        a = np.logspace(-3, 3, 10**6)
+    else:
+        a = np.logspace(-5, 5, 10**6)
+
+    if use_hellskitchen:
+        hellskitchen_result = np.empty((len(a), nterms), dtype=np.complex128)
+        for i, a_i in enumerate(a):
+            if i % 10000 == 0:
+                print "%.1f %%" % (100 * i/len(a))
+            ier, fjs, _, _ = jfuns2d(nterms, a_i, 1, 0, 10000)
+            hellskitchen_result[i] = fjs[:nterms]
+        assert ier == 0
+
     a_dev = cl_array.to_device(queue, a)
 
-    for n in range(0, 30):
+    for n in range(0, nterms):
         cl_bessel = clmath.bessel_jn(n, a_dev).get()
         scipy_bessel = spec.jn(n, a)
 
-        error = np.max(np.abs(cl_bessel-scipy_bessel))
-        print(n, error)
-        assert error < 1e-10
+        error_scipy = np.max(np.abs(cl_bessel-scipy_bessel))
+        assert error_scipy < 1e-10, error_scipy
+
+        if use_hellskitchen:
+            hk_bessel = hellskitchen_result[:, n]
+            error_hk = np.max(np.abs(cl_bessel-hk_bessel))
+            assert error_hk < 1e-10, error_hk
+            error_hk_scipy = np.max(np.abs(scipy_bessel-hk_bessel))
+            print(n, error_scipy, error_hk, error_hk_scipy)
+        else:
+            print(n, error_scipy)
+
         assert not np.isnan(cl_bessel).any()
 
         if 0 and n == 15:
@@ -212,7 +243,10 @@ def test_bessel_j(ctx_factory):
             #pt.plot(scipy_bessel)
             #pt.plot(cl_bessel)
 
-            pt.loglog(a, np.abs(cl_bessel-scipy_bessel))
+            pt.loglog(a, np.abs(cl_bessel-scipy_bessel), label="vs scipy")
+            if use_hellskitchen:
+                pt.loglog(a, np.abs(cl_bessel-hk_bessel), label="vs hellskitchen")
+            pt.legend()
             pt.show()
 
 
-- 
GitLab