diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index aa09ddf5f59473d7eecb1230b20c9bc0b2504f2d..d1ad064476ffe585c40e625bae2bd857bb0ed09b 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -250,15 +250,13 @@ class CSEMatVec(object): def __init__(self, assignments): self.assignments = assignments - def matvec(self, vec, sac): + def matvec(self, vec, wrap_intermediate=lambda x: x): """ :arg vec: vector for the matrix vector multiplication - :arg sac: an object of type - :class:`sumpy.assignment_collection.SymbolicAssignmentCollection` - for storing intermediate row values. If given `None`, the - matvec operation will not use common subexpression elimination - resulting in an expensive matvec + :arg wrap_intermediate: a function to wrap intermediate expressions + If not given, the number of operations might grow in the + final expressions in the vector resulting in an expensive matvec. """ res = [0] * len(self.assignments) stored_idx = 0 @@ -269,26 +267,21 @@ class CSEMatVec(object): else: for k, v in deps: res[row] += res[k] * v - if sac is not None: - new_sym = sym.Symbol(sac.assign_unique("projection_temp", res[row])) - res[row] = new_sym + res[row] = wrap_intermediate(res[row]) return res - def transpose_matvec(self, vec, sac): + def transpose_matvec(self, vec, wrap_intermediate=lambda x: x): res = [] expr_all = list(vec) for row, deps in reversed(list(enumerate(self.assignments))): if len(deps) == 0: res.append(expr_all[row]) continue - if sac is not None: - new_sym = \ - sym.Symbol(sac.assign_unique("compress_temp", expr_all[row])) - for k, v in deps: - expr_all[k] += new_sym * v - else: - for k, v in deps: - expr_all[k] += expr_all[row] * v + + new_sym = wrap_intermediate(expr_all[row]) + for k, v in deps: + expr_all[k] += new_sym * v + res.reverse() return res @@ -316,14 +309,25 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler): def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives, rscale, sac=None): + + def wrap(expr): + if sac is None: + return expr + return sym.Symbol(sac.assign_unique("projection_temp", expr)) + projection_matrix = self.get_projection_matrix(rscale) - return projection_matrix.matvec(stored_kernel_derivatives, sac) + return projection_matrix.matvec(stored_kernel_derivatives, wrap) def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients, rscale, sac=None): - # = M^T x, where M = projection matrix + + def wrap(expr): + if sac is None: + return expr + return sym.Symbol(sac.assign_unique("compress_temp", expr)) + projection_matrix = self.get_projection_matrix(rscale) - return projection_matrix.transpose_matvec(full_mpole_coefficients, sac) + return projection_matrix.transpose_matvec(full_mpole_coefficients, wrap) @property def stored_identifiers(self):