From 6189de9d76e779b0590a14d8c038652cfef0ffe8 Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Thu, 7 May 2020 18:24:57 -0500
Subject: [PATCH] Use the sac

---
 sumpy/expansion/__init__.py  | 10 ++++++----
 sumpy/expansion/local.py     |  4 ++--
 sumpy/expansion/multipole.py |  3 +--
 sumpy/tools.py               | 35 +++++++++++++++++++++++------------
 test/test_kernels.py         |  2 +-
 5 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py
index 5551537a..71af7eba 100644
--- a/sumpy/expansion/__init__.py
+++ b/sumpy/expansion/__init__.py
@@ -161,12 +161,13 @@ class ExpansionBase(object):
     def __ne__(self, other):
         return not self.__eq__(other)
 
-    def get_kernel_derivative_taker(self, dvec, rscale=1):
+    def get_kernel_derivative_taker(self, dvec, rscale, sac):
         """Return a MiDerivativeTaker instance that supports taking
         derivatives of the kernel with respect to dvec
         """
         from sumpy.tools import MiDerivativeTaker
-        return MiDerivativeTaker(self.kernel.get_expression(dvec), dvec, rscale)
+        return MiDerivativeTaker(self.kernel.get_expression(dvec), dvec, rscale,
+                sac)
 
 # }}}
 
@@ -648,9 +649,10 @@ class LaplaceConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
     def __init__(self, kernel, order, use_rscale):
         self.expansion_terms_wrangler_key = (order, kernel.dim)
 
-    def get_kernel_derivative_taker(self, dvec, rscale=1):
+    def get_kernel_derivative_taker(self, dvec, rscale, sac):
         from sumpy.tools import LaplaceDerivativeTaker
-        return LaplaceDerivativeTaker(self.kernel.get_expression(dvec), dvec, rscale)
+        return LaplaceDerivativeTaker(self.kernel.get_expression(dvec), dvec,
+                rscale, sac)
 
 
 class HelmholtzConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py
index 032516ef..c2947c69 100644
--- a/sumpy/expansion/local.py
+++ b/sumpy/expansion/local.py
@@ -124,7 +124,7 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
         from sumpy.tools import MiDerivativeTakerWrapper
 
         result = []
-        taker = self.get_kernel_derivative_taker(avec, rscale)
+        taker = self.get_kernel_derivative_taker(avec, rscale, sac)
         expr_dict = {(0,)*self.dim: 1}
         expr_dict = self.kernel.get_derivative_transformation_at_source(expr_dict)
         pp_nderivatives = single_valued(sum(mi) for mi in expr_dict.keys())
@@ -215,7 +215,7 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
 
             # The vector has the kernel derivatives and depends only on the distance
             # between the two centers
-            taker = src_expansion.get_kernel_derivative_taker(dvec, src_rscale)
+            taker = src_expansion.get_kernel_derivative_taker(dvec, src_rscale, sac)
             vector_stored = []
             # Calculate the kernel derivatives for the compressed set
             for term in \
diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py
index de9aa500..fe8d30a1 100644
--- a/sumpy/expansion/multipole.py
+++ b/sumpy/expansion/multipole.py
@@ -25,7 +25,6 @@ THE SOFTWARE.
 from six.moves import range, zip
 import sumpy.symbolic as sym  # noqa
 
-from sumpy.symbolic import vector_xreplace
 from sumpy.expansion import (
     ExpansionBase, VolumeTaylorExpansion, LaplaceConformingVolumeTaylorExpansion,
     HelmholtzConformingVolumeTaylorExpansion,
@@ -88,7 +87,7 @@ class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase):
         if knl is None:
             knl = self.kernel
 
-        taker = self.get_kernel_derivative_taker(bvec, rscale)
+        taker = self.get_kernel_derivative_taker(bvec, rscale, sac)
         expr_dict = {(0,)*self.dim: 1}
         expr_dict = knl.get_derivative_transformation_at_target(expr_dict)
         pp_nderivatives = single_valued(sum(mi) for mi in expr_dict.keys())
diff --git a/sumpy/tools.py b/sumpy/tools.py
index 4cc529ed..993ae91e 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -104,7 +104,7 @@ def mi_power(vector, mi, evaluate=True):
 
 class MiDerivativeTaker(object):
 
-    def __init__(self, expr, var_list, rscale=1):
+    def __init__(self, expr, var_list, rscale=1, sac=None):
         r"""
         A class to take scaled derivatives of the symbolic expression
         expr w.r.t. variables var_list and the scaling parameter rscale.
@@ -175,6 +175,7 @@ class MiDerivativeTaker(object):
         empty_mi = (0,) * len(var_list)
         self.cache_by_mi = {empty_mi: expr}
         self.rscale = rscale
+        self.sac = sac
 
     def mi_dist(self, a, b):
         return np.array(a, dtype=int) - np.array(b, dtype=int)
@@ -207,13 +208,21 @@ class MiDerivativeTaker(object):
                 if (np.array(mi) >= np.array(other_mi)).all()),
             key=lambda other_mi: sum(self.mi_dist(mi, other_mi)))
 
+    def add_to_sac(self, expr):
+        import sumpy.symbolic as sym
+        if self.sac is not None:
+            return sym.Symbol(self.sac.assign_unique("temp", expr))
+        else:
+            return expr
+
 
 class LaplaceDerivativeTaker(MiDerivativeTaker):
 
-    def __init__(self, expr, var_list, rscale=1):
-        super(LaplaceDerivativeTaker, self).__init__(expr, var_list, rscale)
-        self.r = sym.sqrt(sum(v**2 for v in var_list))
-        self.scaled_r = sym.sqrt(sum((v/rscale)**2 for v in var_list))
+    def __init__(self, expr, var_list, rscale=1, sac=None):
+        super(LaplaceDerivativeTaker, self).__init__(expr, var_list, rscale, sac)
+        self.scaled_var_list = [self.add_to_sac(v/rscale) for v in var_list]
+        self.scaled_r = self.add_to_sac(
+                sym.sqrt(sum(v**2 for v in self.scaled_var_list)))
 
     def diff(self, mi):
         # Return zero for negative values. Makes the algorithm readable.
@@ -235,22 +244,24 @@ class LaplaceDerivativeTaker(MiDerivativeTaker):
             for i in range(dim):
                 mi_minus_one = list(mi)
                 mi_minus_one[i] -= 1
+                mi_minus_one = tuple(mi_minus_one)
                 mi_minus_two = list(mi)
                 mi_minus_two[i] -= 2
-                x = self.var_list[i]
+                mi_minus_two = tuple(mi_minus_two)
+                x = self.scaled_var_list[i]
                 n = mi[i]
                 if i == d:
                     if dim == 3:
-                        expr -= (2*n-1)*(x/self.rscale)*self.diff(tuple(mi_minus_one))
-                        expr -= (n-1)**2*self.diff(tuple(mi_minus_two))
+                        expr -= (2*n - 1) * x * self.diff(mi_minus_one)
+                        expr -= (n - 1)**2 * self.diff(mi_minus_two)
                     else:
-                        expr -= 2*(x/self.rscale)*(n-1)*self.diff(tuple(mi_minus_one))
-                        expr -= (n-1)*(n-2)*self.diff(tuple(mi_minus_two))
+                        expr -= 2 * x * (n - 1) * self.diff(mi_minus_one)
+                        expr -= (n - 1) * (n - 2) * self.diff(mi_minus_two)
                         if n == 2 and sum(mi) == 2:
                             expr += 1
                 else:
-                    expr -= 2*n*(x/self.rscale)*self.diff(tuple(mi_minus_one))
-                    expr -= n*(n-1)*self.diff(tuple(mi_minus_two))
+                    expr -= 2 * n * x * self.diff(mi_minus_one)
+                    expr -= n * (n - 1) * self.diff(mi_minus_two)
             expr /= self.scaled_r**2
             self.cache_by_mi[mi] = expr
         return expr
diff --git a/test/test_kernels.py b/test/test_kernels.py
index f37e51c6..4e445888 100644
--- a/test/test_kernels.py
+++ b/test/test_kernels.py
@@ -366,7 +366,7 @@ def _m2l_translate_simple(tgt_expansion, src_expansion, src_coeff_exprs, src_rsc
     #
     # To get the local expansion coefficients, we take derivatives of
     # the multipole expansion.
-    taker = src_expansion.get_kernel_derivative_taker(dvec, src_rscale)
+    taker = src_expansion.get_kernel_derivative_taker(dvec, src_rscale, sac=None)
 
     from sumpy.tools import add_mi
 
-- 
GitLab