From bcf5310de4c96d80e57263439eaa6fe11890a5db Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Thu, 7 May 2020 22:17:27 -0500
Subject: [PATCH] Use the sac more

---
 sumpy/expansion/local.py | 12 ++++++++----
 sumpy/tools.py           | 25 +++++++++++++------------
 2 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py
index c2947c69..f6cce8fe 100644
--- a/sumpy/expansion/local.py
+++ b/sumpy/expansion/local.py
@@ -32,7 +32,7 @@ from sumpy.expansion import (
     BiharmonicConformingVolumeTaylorExpansion)
 
 from sumpy.tools import (matvec_toeplitz_upper_triangular,
-    fft_toeplitz_upper_triangular)
+    fft_toeplitz_upper_triangular, add_to_sac)
 
 
 class LocalExpansionBase(ExpansionBase):
@@ -238,14 +238,16 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
             vector = [0]*len(toeplitz_matrix_coeffs)
             for i, term in enumerate(toeplitz_matrix_coeffs):
                 if term in srcplusderiv_ident_to_index:
-                    vector[i] = vector_full[srcplusderiv_ident_to_index[term]]
+                    vector[i] = add_to_sac(sac,
+                            vector_full[srcplusderiv_ident_to_index[term]])
 
             # Calculate the first row of the upper triangular Toeplitz matrix
             toeplitz_first_row = [0] * len(toeplitz_matrix_coeffs)
             for coeff, term in zip(
                     src_coeff_exprs,
                     src_expansion.get_coefficient_identifiers()):
-                toeplitz_first_row[toeplitz_matrix_ident_to_index[term]] = coeff
+                toeplitz_first_row[toeplitz_matrix_ident_to_index[term]] = \
+                        add_to_sac(sac, coeff)
 
             # Do the matvec
             if use_fft:
@@ -263,7 +265,9 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
             logger.info("building translation operator: done")
             return result
 
-        rscale_ratio = sym.UnevaluatedExpr(tgt_rscale/src_rscale)
+        rscale_ratio = tgt_rscale/src_rscale
+        if sac is not None:
+            rscale_ratio = sym.Symbol(sac.assign_unique("temp"), rscale_ratio)
 
         from sumpy.tools import MiDerivativeTaker
         from math import factorial
diff --git a/sumpy/tools.py b/sumpy/tools.py
index 993ae91e..26b3f62e 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -102,6 +102,14 @@ def mi_power(vector, mi, evaluate=True):
     return result
 
 
+def add_to_sac(sac, expr):
+    import sumpy.symbolic as sym
+    if sac is not None:
+        return sym.Symbol(sac.assign_unique("temp", expr))
+    else:
+        return expr
+
+
 class MiDerivativeTaker(object):
 
     def __init__(self, expr, var_list, rscale=1, sac=None):
@@ -208,20 +216,13 @@ class MiDerivativeTaker(object):
                 if (np.array(mi) >= np.array(other_mi)).all()),
             key=lambda other_mi: sum(self.mi_dist(mi, other_mi)))
 
-    def add_to_sac(self, expr):
-        import sumpy.symbolic as sym
-        if self.sac is not None:
-            return sym.Symbol(self.sac.assign_unique("temp", expr))
-        else:
-            return expr
-
 
 class LaplaceDerivativeTaker(MiDerivativeTaker):
 
     def __init__(self, expr, var_list, rscale=1, sac=None):
         super(LaplaceDerivativeTaker, self).__init__(expr, var_list, rscale, sac)
-        self.scaled_var_list = [self.add_to_sac(v/rscale) for v in var_list]
-        self.scaled_r = self.add_to_sac(
+        self.scaled_var_list = [add_to_sac(self.sac, v/rscale) for v in var_list]
+        self.scaled_r = add_to_sac(self.sac,
                 sym.sqrt(sum(v**2 for v in self.scaled_var_list)))
 
     def diff(self, mi):
@@ -263,7 +264,7 @@ class LaplaceDerivativeTaker(MiDerivativeTaker):
                     expr -= 2 * n * x * self.diff(mi_minus_one)
                     expr -= n * (n - 1) * self.diff(mi_minus_two)
             expr /= self.scaled_r**2
-            self.cache_by_mi[mi] = expr
+            self.cache_by_mi[mi] = add_to_sac(self.sac, expr)
         return expr
 
 
@@ -1034,8 +1035,8 @@ def matvec_toeplitz_upper_triangular(first_row, vector):
     assert len(vector) == n
     output = [0]*n
     for row in range(n):
-        for col in range(row, n):
-            output[row] += first_row[col-row]*vector[col]
+        terms = tuple(first_row[col-row]*vector[col] for col in range(row, n))
+        output[row] = sym.Add(*terms)
     return output
 
 # }}}
-- 
GitLab