From 1ea4c7964585fe581170c8ef4d29287f60aa79c9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 21 Mar 2013 16:11:08 -0400 Subject: [PATCH] Use pyfmmlib as reference for Bessel functions instead of hellskitchen. --- test/test_clmath.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/test/test_clmath.py b/test/test_clmath.py index c2d25490..6164ebe8 100644 --- a/test/test_clmath.py +++ b/test/test_clmath.py @@ -220,19 +220,21 @@ def test_bessel(ctx_factory): nterms = 30 try: - from hellskitchen._native import jfuns2d, hank103_vec + from pyfmmlib import jfuns2d, hank103_vec except ImportError: - use_hellskitchen = False + use_pyfmmlib = False else: - use_hellskitchen = True + use_pyfmmlib = True - if use_hellskitchen: + print "PYFMMLIB", use_pyfmmlib + + if use_pyfmmlib: a = np.logspace(-3, 3, 10**6) else: a = np.logspace(-5, 5, 10**6) for which_func, cl_func, scipy_func, is_rel in [ - #("j", clmath.bessel_jn, spec.jn, False), + ("j", clmath.bessel_jn, spec.jn, False), ("y", clmath.bessel_yn, spec.yn, True) ]: if is_rel: @@ -242,19 +244,19 @@ def test_bessel(ctx_factory): def get_err(check, ref): return np.max(np.abs(check-ref)) - if use_hellskitchen: - hellskitchen_result = np.empty((len(a), nterms), dtype=np.complex128) + if use_pyfmmlib: + pfymm_result = np.empty((len(a), nterms), dtype=np.complex128) if which_func == "j": for i, a_i in enumerate(a): - if i % 10000 == 0: + if i % 100000 == 0: print("%.1f %%" % (100 * i/len(a))) ier, fjs, _, _ = jfuns2d(nterms, a_i, 1, 0, 10000) - hellskitchen_result[i] = fjs[:nterms] + pfymm_result[i] = fjs[:nterms] assert ier == 0 elif which_func == "y": h0, h1 = hank103_vec(a, ifexpon=1) - hellskitchen_result[:, 0] = h0.imag - hellskitchen_result[:, 1] = h1.imag + pfymm_result[:, 0] = h0.imag + pfymm_result[:, 1] = h1.imag a_dev = cl_array.to_device(queue, a) @@ -265,17 +267,17 @@ def test_bessel(ctx_factory): error_scipy = get_err(cl_bessel, scipy_bessel) assert error_scipy < 1e-10, error_scipy - if use_hellskitchen and ( + if use_pyfmmlib and ( which_func == "j" or (which_func == "y" and n in [0, 1])): - hk_bessel = hellskitchen_result[:, n] - error_hk = get_err(cl_bessel, hk_bessel) - assert error_hk < 1e-10, error_hk - error_hk_scipy = get_err(scipy_bessel, hk_bessel) - print(n, error_scipy, error_hk, error_hk_scipy) + pyfmm_bessel = pfymm_result[:, n] + error_pyfmm = get_err(cl_bessel, pyfmm_bessel) + assert error_pyfmm < 1e-10, error_pyfmm + error_pyfmm_scipy = get_err(scipy_bessel, pyfmm_bessel) + print(which_func, n, error_scipy, error_pyfmm, error_pyfmm_scipy) else: - print(n, error_scipy) + print(which_func, n, error_scipy) assert not np.isnan(cl_bessel).any() @@ -285,8 +287,8 @@ def test_bessel(ctx_factory): #pt.plot(cl_bessel) pt.loglog(a, np.abs(cl_bessel-scipy_bessel), label="vs scipy") - if use_hellskitchen: - pt.loglog(a, np.abs(cl_bessel-hk_bessel), label="vs hellskitchen") + if use_pyfmmlib: + pt.loglog(a, np.abs(cl_bessel-hk_bessel), label="vs pyfmmlib") pt.legend() pt.show() @@ -294,7 +296,6 @@ def test_bessel(ctx_factory): if __name__ == "__main__": - # make sure that import failures get reported, instead of skipping the tests. import sys if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab