From f0ed545656830ec0e78c689fdfd1e7bb29ac7741 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Fri, 30 Dec 2016 19:23:08 -0600 Subject: [PATCH] MiDerivativeTaker: Look in the cache before trying to take the derivative. --- sumpy/tools.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/sumpy/tools.py b/sumpy/tools.py index 88ff5241..a18358c9 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -65,20 +65,23 @@ class MiDerivativeTaker(object): return np.array(a, dtype=int) - np.array(b, dtype=int) def diff(self, mi): - closest_mi = min( - (other_mi - for other_mi in self.cache_by_mi.keys() - if (np.array(mi) >= np.array(other_mi)).all()), - key=lambda other_mi: sum(self.mi_dist(mi, other_mi))) - - expr = self.cache_by_mi[closest_mi] - current_mi = np.array(closest_mi, dtype=int) - for idx, (mi_i, vec_i) in enumerate( - zip(self.mi_dist(mi, closest_mi), self.var_list)): - for i in range(1, 1 + mi_i): - current_mi[idx] += 1 - expr = expr.diff(vec_i) - self.cache_by_mi[tuple(current_mi)] = expr + try: + expr = self.cache_by_mi[mi] + except KeyError: + closest_mi = min( + (other_mi + for other_mi in self.cache_by_mi.keys() + if (np.array(mi) >= np.array(other_mi)).all()), + key=lambda other_mi: sum(self.mi_dist(mi, other_mi))) + + expr = self.cache_by_mi[closest_mi] + current_mi = np.array(closest_mi, dtype=int) + for idx, (mi_i, vec_i) in enumerate( + zip(self.mi_dist(mi, closest_mi), self.var_list)): + for i in range(1, 1 + mi_i): + current_mi[idx] += 1 + expr = expr.diff(vec_i) + self.cache_by_mi[tuple(current_mi)] = expr return expr -- GitLab