From a8d91e12d41b8ca07a2eb108a35f0ffd4c3f0f9d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 2 May 2012 17:32:01 -0400 Subject: [PATCH] Add code to test against Hellskitchen bessel functions. --- test/test_clmath.py | 46 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/test/test_clmath.py b/test/test_clmath.py index 38641380..72a6695a 100644 --- a/test/test_clmath.py +++ b/test/test_clmath.py @@ -187,7 +187,6 @@ def test_bessel_j(ctx_factory): from py.test import skip skip("scipy not present--cannot test Bessel function") - a = np.logspace(-5, 5, 10**6) ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -196,15 +195,47 @@ def test_bessel_j(ctx_factory): from py.test import skip skip("no double precision support--cannot test bessel function") + nterms = 30 + + try: + from hellskitchen._native import jfuns2d + except ImportError: + use_hellskitchen = False + else: + use_hellskitchen = True + + if use_hellskitchen: + a = np.logspace(-3, 3, 10**6) + else: + a = np.logspace(-5, 5, 10**6) + + if use_hellskitchen: + hellskitchen_result = np.empty((len(a), nterms), dtype=np.complex128) + for i, a_i in enumerate(a): + if i % 10000 == 0: + print "%.1f %%" % (100 * i/len(a)) + ier, fjs, _, _ = jfuns2d(nterms, a_i, 1, 0, 10000) + hellskitchen_result[i] = fjs[:nterms] + assert ier == 0 + a_dev = cl_array.to_device(queue, a) - for n in range(0, 30): + for n in range(0, nterms): cl_bessel = clmath.bessel_jn(n, a_dev).get() scipy_bessel = spec.jn(n, a) - error = np.max(np.abs(cl_bessel-scipy_bessel)) - print(n, error) - assert error < 1e-10 + error_scipy = np.max(np.abs(cl_bessel-scipy_bessel)) + assert error_scipy < 1e-10, error_scipy + + if use_hellskitchen: + hk_bessel = hellskitchen_result[:, n] + error_hk = np.max(np.abs(cl_bessel-hk_bessel)) + assert error_hk < 1e-10, error_hk + error_hk_scipy = np.max(np.abs(scipy_bessel-hk_bessel)) + print(n, error_scipy, error_hk, error_hk_scipy) + else: + print(n, error_scipy) + assert not np.isnan(cl_bessel).any() if 0 and n == 15: @@ -212,7 +243,10 @@ def test_bessel_j(ctx_factory): #pt.plot(scipy_bessel) #pt.plot(cl_bessel) - pt.loglog(a, np.abs(cl_bessel-scipy_bessel)) + pt.loglog(a, np.abs(cl_bessel-scipy_bessel), label="vs scipy") + if use_hellskitchen: + pt.loglog(a, np.abs(cl_bessel-hk_bessel), label="vs hellskitchen") + pt.legend() pt.show() -- GitLab