Skip to content
Snippets Groups Projects
Commit baaf7ef7 authored by Isuru Fernando's avatar Isuru Fernando
Browse files

Remove rscale from try_get_recurrence_for_derivative

parent a9051281
No related branches found
No related tags found
No related merge requests found
...@@ -246,7 +246,6 @@ def _spmv(spmat, x, sparse_vectors): ...@@ -246,7 +246,6 @@ def _spmv(spmat, x, sparse_vectors):
class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler):
_rscale_symbol = sp.Symbol("_sumpy_rscale_placeholder")
def get_coefficient_identifiers(self): def get_coefficient_identifiers(self):
return self.stored_identifiers return self.stored_identifiers
...@@ -295,15 +294,21 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): ...@@ -295,15 +294,21 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler):
""" """
stored_identifiers, coeff_matrix = self._get_stored_ids_and_coeff_mat() stored_identifiers, coeff_matrix = self._get_stored_ids_and_coeff_mat()
# substitute actual rscale for internal placeholder full_coeffs = self.get_full_coefficient_identifiers()
return defaultdict(lambda: [], matrix_rows = []
((irow, [ for irow, row in six.iteritems(coeff_matrix):
(icol, # For eg: (u_xxx / rscale**3) = (u_yy / rscale**2) * coeff1 +
coeff.xreplace({self._rscale_symbol: rscale}) # (u_xx / rscale**2) * coeff2
if isinstance(coeff, sp.Basic) # is converted to u_xxx = u_yy * (rscale * coeff1) +
else coeff) # u_xx * (rscale * coeff2)
for icol, coeff in row]) row_rscale = sum(full_coeffs[irow])
for irow, row in six.iteritems(coeff_matrix))) matrix_row = []
for icol, coeff in row:
diff = row_rscale - sum(stored_identifiers[icol])
matrix_row.append((icol, coeff * rscale**diff))
matrix_rows.append((irow, matrix_row))
return defaultdict(lambda: [], matrix_rows)
@memoize_method @memoize_method
def _get_stored_ids_and_coeff_mat(self): def _get_stored_ids_and_coeff_mat(self):
...@@ -321,7 +326,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): ...@@ -321,7 +326,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler):
from six import iteritems from six import iteritems
for i, identifier in enumerate(self.get_full_coefficient_identifiers()): for i, identifier in enumerate(self.get_full_coefficient_identifiers()):
expr = self.try_get_recurrence_for_derivative( expr = self.try_get_recurrence_for_derivative(
identifier, identifiers_so_far, rscale=self._rscale_symbol) identifier, identifiers_so_far)
if expr is None: if expr is None:
# Identifier should be stored # Identifier should be stored
...@@ -353,7 +358,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): ...@@ -353,7 +358,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler):
return stored_identifiers, coeff_matrix return stored_identifiers, coeff_matrix
def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): def try_get_recurrence_for_derivative(self, deriv, in_terms_of):
""" """
:arg deriv: a tuple of integers identifying a derivative for which :arg deriv: a tuple of integers identifying a derivative for which
a recurrence is sought a recurrence is sought
...@@ -370,7 +375,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler): ...@@ -370,7 +375,7 @@ class LinearRecurrenceBasedDerivativeWrangler(DerivativeWrangler):
class LaplaceDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): class LaplaceDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler):
def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): def try_get_recurrence_for_derivative(self, deriv, in_terms_of):
deriv = np.array(deriv, dtype=int) deriv = np.array(deriv, dtype=int)
for dim in np.where(2 <= deriv)[0]: for dim in np.where(2 <= deriv)[0]:
...@@ -402,7 +407,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): ...@@ -402,7 +407,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler):
super(HelmholtzDerivativeWrangler, self).__init__(order, dim) super(HelmholtzDerivativeWrangler, self).__init__(order, dim)
self.helmholtz_k_name = helmholtz_k_name self.helmholtz_k_name = helmholtz_k_name
def try_get_recurrence_for_derivative(self, deriv, in_terms_of, rscale): def try_get_recurrence_for_derivative(self, deriv, in_terms_of):
deriv = np.array(deriv, dtype=int) deriv = np.array(deriv, dtype=int)
for dim in np.where(2 <= deriv)[0]: for dim in np.where(2 <= deriv)[0]:
...@@ -426,7 +431,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler): ...@@ -426,7 +431,7 @@ class HelmholtzDerivativeWrangler(LinearRecurrenceBasedDerivativeWrangler):
coeffs[needed_deriv] = -1 coeffs[needed_deriv] = -1
else: else:
k = sym.Symbol(self.helmholtz_k_name) k = sym.Symbol(self.helmholtz_k_name)
coeffs[tuple(reduced_deriv)] = -k*k*rscale*rscale coeffs[tuple(reduced_deriv)] = -k*k
return coeffs return coeffs
# }}} # }}}
......
...@@ -126,7 +126,7 @@ class LinearRecurrenceBasedMiDerivativeTaker(MiDerivativeTaker): ...@@ -126,7 +126,7 @@ class LinearRecurrenceBasedMiDerivativeTaker(MiDerivativeTaker):
recurrence = ( recurrence = (
self.wrangler.try_get_recurrence_for_derivative( self.wrangler.try_get_recurrence_for_derivative(
next_mi, self.cache_by_mi, rscale=1)) next_mi, self.cache_by_mi))
if recurrence is not None: if recurrence is not None:
expr = Add(*tuple( expr = Add(*tuple(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment