From baaf7ef74c0b74c4b70fd7202aa78280a8fd32ef Mon Sep 17 00:00:00 2001 From: Isuru Fernando <isuruf@gmail.com> Date: Mon, 29 Jan 2018 15:04:54 -0600 Subject: [PATCH] Remove rscale from try_get_recurrence_for_derivative --- sumpy/expansion/__init__.py | 35 ++++++++++++++++++++--------------- sumpy/tools.py | 2 +- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index 5432a54e..15a4c447 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -246,7 +246,6 @@ def _spmv(spmat, x, sparse_vectors): class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): - _rscale_symbol = sp.Symbol("_sumpy_rscale_placeholder") def get_coefficient_identifiers(self): return self.stored_identifiers @@ -295,15 +294,21 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): """ stored_identifiers, coeff_matrix = self._get_stored_ids_and_coeff_mat() - # substitute actual rscale for internal placeholder - return defaultdict(lambda: [], - ((irow, [ - (icol, - coeff.xreplace({self._rscale_symbol: rscale}) - if isinstance(coeff, sp.Basic) - else coeff) - for icol, coeff in row]) - for irow, row in six.iteritems(coeff_matrix))) + full_coeffs = self.get_full_coefficient_identifiers() + matrix_rows = [] + for irow, row in six.iteritems(coeff_matrix): + # For eg: (u_xxx / rscale**3) = (u_yy / rscale**2) * coeff1 + + # (u_xx / rscale**2) * coeff2 + # is converted to u_xxx = u_yy * (rscale * coeff1) + + # u_xx * (rscale * coeff2) + row_rscale = sum(full_coeffs[irow]) + matrix_row = [] + for icol, coeff in row: + diff = row_rscale - sum(stored_identifiers[icol]) + matrix_row.append((icol, coeff * rscale**diff)) + matrix_rows.append((irow, matrix_row)) + + return defaultdict(lambda: [], matrix_rows) @memoize_method def _get_stored_ids_and_coeff_mat(self): @@ -321,7 +326,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): from six import iteritems for i, identifier in enumerate(self.get_full_coefficient_identifiers()): expr = self.try_get_recurrence_for_derivative( - identifier, identifiers_so_far, rscale=self._rscale_symbol) + identifier, identifiers_so_far) if expr is None: # Identifier should be stored @@ -353,7 +358,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): return stored_identifiers, coeff_matrix - def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): + def try_get_recurrence_for_derivative(self, deriv, in_terms_of): """ :arg deriv: a tuple of integers identifying a derivative for which a recurrence is sought @@ -370,7 +375,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): class LaplaceDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): - def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): + def try_get_recurrence_for_derivative(self, deriv, in_terms_of): deriv = np.array(deriv, dtype=int) for dim in np.where(2 <= deriv)[0]: @@ -402,7 +407,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): super(HelmholtzDerivativeWrangler, self).__init__(order, dim) self.helmholtz_k_name = helmholtz_k_name - def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): + def try_get_recurrence_for_derivative(self, deriv, in_terms_of): deriv = np.array(deriv, dtype=int) for dim in np.where(2 <= deriv)[0]: @@ -426,7 +431,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): coeffs[needed_deriv] = -1 else: k = sym.Symbol(self.helmholtz_k_name) - coeffs[tuple(reduced_deriv)] = -k*k*rscale*rscale + coeffs[tuple(reduced_deriv)] = -k*k return coeffs # }}} diff --git a/sumpy/tools.py b/sumpy/tools.py index 46fa9ee9..964c64cc 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -126,7 +126,7 @@ class LinearRecurrenceBasedMiDerivativeTaker(MiDerivativeTaker): recurrence = ( self.wrangler.try_get_recurrence_for_derivative( - next_mi, self.cache_by_mi, rscale=1)) + next_mi, self.cache_by_mi)) if recurrence is not None: expr = Add(*tuple( -- GitLab