diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index 5432a54e22a1400efbb31118c6328b85fd31d980..4a47ffda036db262d21ec8c0a55f945f76cebe68 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -29,7 +29,6 @@ import logging from pytools import memoize_method import sumpy.symbolic as sym from sumpy.tools import MiDerivativeTaker -import sumpy.symbolic as sp from collections import defaultdict @@ -246,7 +245,6 @@ def _spmv(spmat, x, sparse_vectors): class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): - _rscale_symbol = sp.Symbol("_sumpy_rscale_placeholder") def get_coefficient_identifiers(self): return self.stored_identifiers @@ -295,15 +293,21 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): """ stored_identifiers, coeff_matrix = self._get_stored_ids_and_coeff_mat() - # substitute actual rscale for internal placeholder - return defaultdict(lambda: [], - ((irow, [ - (icol, - coeff.xreplace({self._rscale_symbol: rscale}) - if isinstance(coeff, sp.Basic) - else coeff) - for icol, coeff in row]) - for irow, row in six.iteritems(coeff_matrix))) + full_coeffs = self.get_full_coefficient_identifiers() + matrix_rows = [] + for irow, row in six.iteritems(coeff_matrix): + # For eg: (u_xxx / rscale**3) = (u_yy / rscale**2) * coeff1 + + # (u_xx / rscale**2) * coeff2 + # is converted to u_xxx = u_yy * (rscale * coeff1) + + # u_xx * (rscale * coeff2) + row_rscale = sum(full_coeffs[irow]) + matrix_row = [] + for icol, coeff in row: + diff = row_rscale - sum(stored_identifiers[icol]) + matrix_row.append((icol, coeff * rscale**diff)) + matrix_rows.append((irow, matrix_row)) + + return defaultdict(lambda: [], matrix_rows) @memoize_method def _get_stored_ids_and_coeff_mat(self): @@ -321,7 +325,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): from six import iteritems for i, identifier in enumerate(self.get_full_coefficient_identifiers()): expr = self.try_get_recurrence_for_derivative( - identifier, identifiers_so_far, rscale=self._rscale_symbol) + identifier, identifiers_so_far) if expr is None: # Identifier should be stored @@ -353,7 +357,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): return stored_identifiers, coeff_matrix - def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): + def try_get_recurrence_for_derivative(self, deriv, in_terms_of): """ :arg deriv: a tuple of integers identifying a derivative for which a recurrence is sought @@ -370,7 +374,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): class LaplaceDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): - def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): + def try_get_recurrence_for_derivative(self, deriv, in_terms_of): deriv = np.array(deriv, dtype=int) for dim in np.where(2 <= deriv)[0]: @@ -402,7 +406,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): super(HelmholtzDerivativeWrangler, self).__init__(order, dim) self.helmholtz_k_name = helmholtz_k_name - def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): + def try_get_recurrence_for_derivative(self, deriv, in_terms_of): deriv = np.array(deriv, dtype=int) for dim in np.where(2 <= deriv)[0]: @@ -426,7 +430,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): coeffs[needed_deriv] = -1 else: k = sym.Symbol(self.helmholtz_k_name) - coeffs[tuple(reduced_deriv)] = -k*k*rscale*rscale + coeffs[tuple(reduced_deriv)] = -k*k return coeffs # }}} diff --git a/sumpy/tools.py b/sumpy/tools.py index 46fa9ee96c10b49a859871a91354e37c5cda28d9..964c64cc440c807c02b3f8811c6a3f0de6c84b04 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -126,7 +126,7 @@ class LinearRecurrenceBasedMiDerivativeTaker(MiDerivativeTaker): recurrence = ( self.wrangler.try_get_recurrence_for_derivative( - next_mi, self.cache_by_mi, rscale=1)) + next_mi, self.cache_by_mi)) if recurrence is not None: expr = Add(*tuple( diff --git a/sumpy/version.py b/sumpy/version.py index 0efa5f08fbea794b5b10e5f1fcb05451d9956433..a788c6d6cf61145b77b833ff172d88e7514e17ef 100644 --- a/sumpy/version.py +++ b/sumpy/version.py @@ -25,4 +25,4 @@ VERSION = (2016, 1) VERSION_STATUS = "beta1" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS -KERNEL_VERSION = 23 +KERNEL_VERSION = 24