From bec10d9fa337fb13501b57238504c793c693d68a Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 7 Sep 2021 09:23:56 -0500
Subject: [PATCH] add derivatives to TargetPointMultiplier test

---
 test/test_fmm.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/test/test_fmm.py b/test/test_fmm.py
index 73b80ef0..69be8a2f 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -513,7 +513,8 @@ def test_sumpy_axis_source_derivative(ctx_factory):
     assert np.isclose(rel_err, 0, atol=1e-5)
 
 
-def test_sumpy_target_point_multiplier(ctx_factory):
+@pytest.mark.parametrize("deriv_axes", [(), (0,), (1,)])
+def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes):
     logging.basicConfig(level=logging.INFO)
 
     ctx = ctx_factory()
@@ -552,9 +553,12 @@ def test_sumpy_target_point_multiplier(ctx_factory):
     from functools import partial
 
     from sumpy.fmm import SumpyExpansionWranglerCodeContainer
-    from sumpy.kernel import TargetPointMultiplier
+    from sumpy.kernel import TargetPointMultiplier, AxisTargetDerivative
 
-    tgt_knls = [TargetPointMultiplier(0, knl), knl]
+    tgt_knls = [TargetPointMultiplier(0, knl), knl, knl]
+    for axis in deriv_axes:
+        tgt_knls[0] = AxisTargetDerivative(axis, tgt_knls[0])
+        tgt_knls[1] = AxisTargetDerivative(axis, tgt_knls[1])
 
     wcc = SumpyExpansionWranglerCodeContainer(
             ctx,
@@ -571,11 +575,14 @@ def test_sumpy_target_point_multiplier(ctx_factory):
 
     from boxtree.fmm import drive_fmm
 
-    pot0, pot1 = drive_fmm(trav, wrangler, (weights,))
-    pot0, pot1 = pot0.get(), pot1.get()
-    pot1 = pot1 * sources[0].get()
+    pot0, pot1, pot2 = drive_fmm(trav, wrangler, (weights,))
+    pot0, pot1, pot2 = pot0.get(), pot1.get(), pot2.get()
+    if deriv_axes == (0,):
+        ref_pot = pot1 * sources[0].get() + pot2
+    else:
+        ref_pot = pot1 * sources[0].get()
 
-    rel_err = la.norm(pot0 - pot1) / la.norm(pot1)
+    rel_err = la.norm(pot0 - ref_pot) / la.norm(ref_pot)
     logger.info("order %d -> relative l2 error: %g" % (order, rel_err))
 
     assert np.isclose(rel_err, 0, atol=1e-5)
-- 
GitLab