diff --git a/sumpy/tools.py b/sumpy/tools.py index bec9ea9ba14e1a4a5b7d59c0be57d1ea501bfa68..89192f78b0af71fb03adef00014a96918b3b1a89 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -46,6 +46,7 @@ import numpy as np import sumpy.symbolic as sym import loopy as lp +from typing import Dict, Tuple, Any import logging logger = logging.getLogger(__name__) @@ -394,6 +395,9 @@ class HelmholtzDerivativeTaker(RadialDerivativeTaker): return expr +DerivativeCoeffDict = Dict[Tuple[int], Any] + + @tag_dataclass class DifferentiatedExprDerivativeTaker: """Implements the :class:`ExprDerivativeTaker` interface @@ -411,7 +415,7 @@ class DifferentiatedExprDerivativeTaker: base expression. """ taker: ExprDerivativeTaker - derivative_coeff_dict: dict + derivative_coeff_dict: DerivativeCoeffDict def diff(self, mi, save_intermediate=lambda x: x): # By passing `rscale` to the derivative taker we are taking a scaled @@ -432,7 +436,8 @@ class DifferentiatedExprDerivativeTaker: return result * save_intermediate(1 / self.taker.rscale ** max_order) -def diff_derivative_coeff_dict(derivative_coeff_dict, variable_idx, variables): +def diff_derivative_coeff_dict(derivative_coeff_dict: DerivativeCoeffDict, + variable_idx, variables): """Differentiate a derivative transformation dictionary given by *derivative_coeff_dict* using the variable given by **variable_idx** and return a new derivative transformation dictionary.