From bec10d9fa337fb13501b57238504c793c693d68a Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Tue, 7 Sep 2021 09:23:56 -0500 Subject: [PATCH] add derivatives to TargetPointMultiplier test --- test/test_fmm.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/test/test_fmm.py b/test/test_fmm.py index 73b80ef0..69be8a2f 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -513,7 +513,8 @@ def test_sumpy_axis_source_derivative(ctx_factory): assert np.isclose(rel_err, 0, atol=1e-5) -def test_sumpy_target_point_multiplier(ctx_factory): +@pytest.mark.parametrize("deriv_axes", [(), (0,), (1,)]) +def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes): logging.basicConfig(level=logging.INFO) ctx = ctx_factory() @@ -552,9 +553,12 @@ def test_sumpy_target_point_multiplier(ctx_factory): from functools import partial from sumpy.fmm import SumpyExpansionWranglerCodeContainer - from sumpy.kernel import TargetPointMultiplier + from sumpy.kernel import TargetPointMultiplier, AxisTargetDerivative - tgt_knls = [TargetPointMultiplier(0, knl), knl] + tgt_knls = [TargetPointMultiplier(0, knl), knl, knl] + for axis in deriv_axes: + tgt_knls[0] = AxisTargetDerivative(axis, tgt_knls[0]) + tgt_knls[1] = AxisTargetDerivative(axis, tgt_knls[1]) wcc = SumpyExpansionWranglerCodeContainer( ctx, @@ -571,11 +575,14 @@ def test_sumpy_target_point_multiplier(ctx_factory): from boxtree.fmm import drive_fmm - pot0, pot1 = drive_fmm(trav, wrangler, (weights,)) - pot0, pot1 = pot0.get(), pot1.get() - pot1 = pot1 * sources[0].get() + pot0, pot1, pot2 = drive_fmm(trav, wrangler, (weights,)) + pot0, pot1, pot2 = pot0.get(), pot1.get(), pot2.get() + if deriv_axes == (0,): + ref_pot = pot1 * sources[0].get() + pot2 + else: + ref_pot = pot1 * sources[0].get() - rel_err = la.norm(pot0 - pot1) / la.norm(pot1) + rel_err = la.norm(pot0 - ref_pot) / la.norm(ref_pot) logger.info("order %d -> relative l2 error: %g" % (order, rel_err)) assert np.isclose(rel_err, 0, atol=1e-5) -- GitLab