From bbafb8f01abce12f65785be07ff36b4e009cc16b Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Wed, 8 Jul 2020 18:40:16 -0500
Subject: [PATCH] Use wrap_intermediate callback function instead of sac

---
 sumpy/expansion/__init__.py | 46 ++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 21 deletions(-)

diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py
index aa09ddf5..d1ad0644 100644
--- a/sumpy/expansion/__init__.py
+++ b/sumpy/expansion/__init__.py
@@ -250,15 +250,13 @@ class CSEMatVec(object):
     def __init__(self, assignments):
         self.assignments = assignments
 
-    def matvec(self, vec, sac):
+    def matvec(self, vec, wrap_intermediate=lambda x: x):
         """
         :arg vec: vector for the matrix vector multiplication
 
-        :arg sac: an object of type
-                  :class:`sumpy.assignment_collection.SymbolicAssignmentCollection`
-                  for storing intermediate row values. If given `None`, the
-                  matvec operation will not use common subexpression elimination
-                  resulting in an expensive matvec
+        :arg wrap_intermediate: a function to wrap intermediate expressions
+             If not given, the number of operations might grow in the
+             final expressions in the vector resulting in an expensive matvec.
         """
         res = [0] * len(self.assignments)
         stored_idx = 0
@@ -269,26 +267,21 @@ class CSEMatVec(object):
             else:
                 for k, v in deps:
                     res[row] += res[k] * v
-            if sac is not None:
-                new_sym = sym.Symbol(sac.assign_unique("projection_temp", res[row]))
-                res[row] = new_sym
+            res[row] = wrap_intermediate(res[row])
         return res
 
-    def transpose_matvec(self, vec, sac):
+    def transpose_matvec(self, vec, wrap_intermediate=lambda x: x):
         res = []
         expr_all = list(vec)
         for row, deps in reversed(list(enumerate(self.assignments))):
             if len(deps) == 0:
                 res.append(expr_all[row])
                 continue
-            if sac is not None:
-                new_sym = \
-                    sym.Symbol(sac.assign_unique("compress_temp", expr_all[row]))
-                for k, v in deps:
-                    expr_all[k] += new_sym * v
-            else:
-                for k, v in deps:
-                    expr_all[k] += expr_all[row] * v
+
+            new_sym = wrap_intermediate(expr_all[row])
+            for k, v in deps:
+                expr_all[k] += new_sym * v
+
         res.reverse()
         return res
 
@@ -316,14 +309,25 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
 
     def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives,
             rscale, sac=None):
+
+        def wrap(expr):
+            if sac is None:
+                return expr
+            return sym.Symbol(sac.assign_unique("projection_temp", expr))
+
         projection_matrix = self.get_projection_matrix(rscale)
-        return projection_matrix.matvec(stored_kernel_derivatives, sac)
+        return projection_matrix.matvec(stored_kernel_derivatives, wrap)
 
     def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients,
             rscale, sac=None):
-        # = M^T x, where M = projection matrix
+
+        def wrap(expr):
+            if sac is None:
+                return expr
+            return sym.Symbol(sac.assign_unique("compress_temp", expr))
+
         projection_matrix = self.get_projection_matrix(rscale)
-        return projection_matrix.transpose_matvec(full_mpole_coefficients, sac)
+        return projection_matrix.transpose_matvec(full_mpole_coefficients, wrap)
 
     @property
     def stored_identifiers(self):
-- 
GitLab