From ba0bc497ed24d37ac5855bb68fb468dbfe264d6b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 13 Nov 2017 12:22:26 -0600
Subject: [PATCH] Toys: Add combine_halfspace_and_outer

---
 sumpy/toys.py | 59 +++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 50 insertions(+), 9 deletions(-)

diff --git a/sumpy/toys.py b/sumpy/toys.py
index 55b33137..2804b819 100644
--- a/sumpy/toys.py
+++ b/sumpy/toys.py
@@ -47,7 +47,6 @@ local and multipole expansions.
 .. autoclass:: ToyContext
 .. autoclass:: PotentialSource
 .. autoclass:: ConstantPotential
-.. autoclass:: OneOnBallPotential
 .. autoclass:: PointSources
 
 These functions manipulate these pootentials:
@@ -55,8 +54,9 @@ These functions manipulate these pootentials:
 .. autofunction:: multipole_expand
 .. autofunction:: local_expand
 .. autofunction:: logplot
-.. autofunction:: restrict_inner
-.. autofunction:: restrict_outer
+.. autofunction:: combine_inner_outer
+.. autofunction:: combine_halfspace
+.. autofunction:: combine_halfspace_and_outer
 
 These fucntions help with plotting:
 
@@ -68,6 +68,8 @@ These fucntions help with plotting:
 These are created behind the scenes and are not typically directly instantiated
 by users:
 
+.. autoclass:: OneOnBallPotential
+.. autoclass:: HalfspaceOnePotential
 .. autoclass:: ExpansionPotentialSource
 .. autoclass:: MultipoleExpansion
 .. autoclass:: LocalExpansion
@@ -370,6 +372,22 @@ class OneOnBallPotential(PotentialSource):
         return (np.sum(dist_vec**2, axis=0) < self.radius**2).astype(np.float64)
 
 
+class HalfspaceOnePotential(PotentialSource):
+    """
+    .. automethod:: __init__
+    """
+    def __init__(self, toy_ctx, center, axis, side=1):
+        super(HalfspaceOnePotential, self).__init__(toy_ctx)
+        self.center = np.asarray(center)
+        self.axis = axis
+        self.side = side
+
+    def eval(self, targets):
+        return (
+            (self.side*(targets[self.axis] - self.center[self.axis])) >= 0
+            ).astype(np.float64)
+
+
 class PointSources(PotentialSource):
     """
     .. attribute:: points
@@ -532,18 +550,41 @@ def logplot(fp, psource, **kwargs):
             np.log10(np.abs(psource.eval(fp.points) + 1e-15)), **kwargs)
 
 
-def restrict_inner(psource, radius, center=None):
+def combine_inner_outer(psource_inner, psource_outer, radius, center=None):
     if center is None:
-        center = psource.center
+        center = psource_inner.center
+    if radius is None:
+        radius = psource_inner.radius
 
-    return psource * OneOnBallPotential(psource.toy_ctx, center, radius)
+    ball_one = OneOnBallPotential(psource_inner.toy_ctx, center, radius)
+    return (
+            psource_inner * ball_one
+            +
+            psource_outer * (1 - ball_one))
 
 
-def restrict_outer(psource, radius, center=None):
+def combine_halfspace(psource_pos, psource_neg, axis, center=None):
     if center is None:
-        center = psource.center
+        center = psource_pos.center
+
+    halfspace_one = HalfspaceOnePotential(psource_pos.toy_ctx, center, axis)
+    return (
+        psource_pos * halfspace_one
+        +
+        psource_neg * (1-halfspace_one))
+
+
+def combine_halfspace_and_outer(psource_pos, psource_neg, psource_outer,
+        axis, radius=None, center=None):
+
+    if center is None:
+        center = psource_pos.center
+    if radius is None:
+        center = psource_pos.radius
 
-    return psource * (1-OneOnBallPotential(psource.toy_ctx, center, radius))
+    return combine_inner_outer(
+            combine_halfspace(psource_pos, psource_neg, axis, center),
+            psource_outer, radius, center)
 
 
 def l_inf(psource, radius, center=None, npoints=100, debug=False):
-- 
GitLab