diff --git a/sumpy/tools.py b/sumpy/tools.py index 88ff5241799c210a522e8b54935ff0d1ae9dde47..a18358c939e9932340058487c6af20efc4951fe3 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -65,20 +65,23 @@ class MiDerivativeTaker(object): return np.array(a, dtype=int) - np.array(b, dtype=int) def diff(self, mi): - closest_mi = min( - (other_mi - for other_mi in self.cache_by_mi.keys() - if (np.array(mi) >= np.array(other_mi)).all()), - key=lambda other_mi: sum(self.mi_dist(mi, other_mi))) - - expr = self.cache_by_mi[closest_mi] - current_mi = np.array(closest_mi, dtype=int) - for idx, (mi_i, vec_i) in enumerate( - zip(self.mi_dist(mi, closest_mi), self.var_list)): - for i in range(1, 1 + mi_i): - current_mi[idx] += 1 - expr = expr.diff(vec_i) - self.cache_by_mi[tuple(current_mi)] = expr + try: + expr = self.cache_by_mi[mi] + except KeyError: + closest_mi = min( + (other_mi + for other_mi in self.cache_by_mi.keys() + if (np.array(mi) >= np.array(other_mi)).all()), + key=lambda other_mi: sum(self.mi_dist(mi, other_mi))) + + expr = self.cache_by_mi[closest_mi] + current_mi = np.array(closest_mi, dtype=int) + for idx, (mi_i, vec_i) in enumerate( + zip(self.mi_dist(mi, closest_mi), self.var_list)): + for i in range(1, 1 + mi_i): + current_mi[idx] += 1 + expr = expr.diff(vec_i) + self.cache_by_mi[tuple(current_mi)] = expr return expr