From 70690376d5c0ff9f59beca2c5b201e85b081d791 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 22 Oct 2018 19:45:35 -0500
Subject: [PATCH] Various extensions to the calculus patch

---
 sumpy/point_calculus.py | 68 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/sumpy/point_calculus.py b/sumpy/point_calculus.py
index 9321f7d9..f8b4f5bd 100644
--- a/sumpy/point_calculus.py
+++ b/sumpy/point_calculus.py
@@ -43,6 +43,8 @@ class CalculusPatch(object):
 
         shape: ``(dim, npoints_total)``
 
+    .. automethod:: weights
+    .. automethod:: basis
     .. automethod:: diff
     .. automethod:: dx
     .. automethod:: dy
@@ -55,6 +57,8 @@ class CalculusPatch(object):
     .. autoattribute:: y
     .. autoattribute:: z
     .. automethod:: norm
+    .. automethod:: plot_nodes
+    .. automethod:: plot
     """
     def __init__(self, center, h=1e-1, order=4, nodes="chebyshev"):
         self.center = center
@@ -62,15 +66,26 @@ class CalculusPatch(object):
         npoints = order + 1
         if nodes == "equispaced":
             points_1d = np.linspace(-h/2, h/2, npoints)
+            weights_1d = None
 
         elif nodes == "chebyshev":
             a = np.arange(npoints, dtype=np.float64)
             points_1d = (h/2)*np.cos((2*(a+1)-1)/(2*npoints)*np.pi)
+            weights_1d = None
+
+        elif nodes == "legendre":
+            from scipy.special import legendre
+            points_1d, weights_1d, _ = legendre(npoints).weights.T
+            points_1d = points_1d * (h/2)
+            weights_1d = weights_1d * (h/2)
 
         else:
             raise ValueError("invalid node set: %s" % nodes)
 
+        self.h = h
+        self.npoints = npoints
         self._points_1d = points_1d
+        self._weights_1d = weights_1d
 
         self.dim = dim = len(self.center)
         self.center = center
@@ -100,6 +115,44 @@ class CalculusPatch(object):
         # The zeroth coefficient--all others involve x=0.
         return self._vandermonde_1d()[0]
 
+    def basis(self):
+        """"
+        :returns: a :class:`list` containing functions that realize
+            a high-order interpolation basis on the :attr:`points`.
+        """
+
+        from pytools import indices_in_shape
+        from scipy.special import eval_chebyt
+
+        def eval_basis(ind, x):
+            result = 1
+            for i in range(self.dim):
+                coord = (x[i] - self.center[i])/(self.h/2)
+                result *= eval_chebyt(ind[i], coord)
+            return result
+
+        from functools import partial
+        return [
+                partial(eval_basis, ind)
+                for ind in indices_in_shape((self.npoints,)*self.dim)]
+
+    @memoize_method
+    def weights(self):
+        """"
+        :returns: a vector of high-order quadrature weights on the :attr:`points`
+        """
+
+        if self._weights_1d is None:
+            raise NotImplementedError("weights not available for these nodes")
+
+        result = np.ones_like(self._points_shaped[0])
+        for i in range(self.dim):
+            result = result * self._weights_1d[
+                    (slice(None),)
+                    + i*(np.newaxis,)]
+
+        return result.reshape(-1)
+
     @memoize_method
     def _diff_mat_1d(self, nderivs):
         npoints = len(self._points_1d)
@@ -217,6 +270,21 @@ class CalculusPatch(object):
         else:
             raise ValueError("unsupported norm")
 
+    def plot_nodes(self):
+        import matplotlib.pyplot as plt
+        plt.gca().set_aspect("equal")
+        plt.plot(
+            self._points_shaped[0].reshape(-1),
+            self._points_shaped[1].reshape(-1),
+            "o")
+
+    def plot(self, f):
+        f = f.reshape(*self._pshape)
+
+        import matplotlib.pyplot as plt
+        plt.gca().set_aspect("equal")
+        plt.contourf(self._points_1d, self._points_1d, f)
+
 
 def frequency_domain_maxwell(cpatch, e, h, k):
     mu = 1
-- 
GitLab