diff --git a/sumpy/tools.py b/sumpy/tools.py index daad385301a6be0d5a4cecbff9f26b83b94f5d0a..25addd1851e4893a9c459a2352f0921d7aced67b 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -107,11 +107,11 @@ Profiling # {{{ multi_index helpers -def add_mi(mi1, mi2): +def add_mi(mi1: Sequence[int], mi2: Sequence[int]) -> Tuple[int, ...]: return tuple([mi1i + mi2i for mi1i, mi2i in zip(mi1, mi2)]) -def mi_factorial(mi): +def mi_factorial(mi: Sequence[int]) -> int: import math result = 1 for mi_i in mi: @@ -119,19 +119,23 @@ def mi_factorial(mi): return result -def mi_increment_axis(mi, axis, increment): +def mi_increment_axis( + mi: Sequence[int], axis: int, increment: int + ) -> Tuple[int, ...]: new_mi = list(mi) new_mi[axis] += increment return tuple(new_mi) -def mi_set_axis(mi, axis, value): +def mi_set_axis(mi: Sequence[int], axis: int, value: int) -> Tuple[int, ...]: new_mi = list(mi) new_mi[axis] = value return tuple(new_mi) -def mi_power(vector, mi, evaluate=True): +def mi_power( + vector: Sequence[T], mi: Sequence[int], + evaluate: bool = True) -> T: result = 1 for mi_i, vec_i in zip(mi, vector): if mi_i == 1: @@ -147,8 +151,8 @@ def add_to_sac(sac, expr): if sac is None: return expr - if isinstance(expr, (numbers.Number, sym.Number, int, - float, complex, sym.Symbol)): + from numbers import Number + if isinstance(expr, (Number, sym.Number, sym.Symbol)): return expr name = sac.assign_temp("temp", expr) @@ -280,7 +284,7 @@ class KernelComputation(ABC): target_kernels: List["Kernel"], source_kernels: List["Kernel"], strength_usage: Optional[List[int]] = None, - value_dtypes: Optional[List["np.dtype"]] = None, + value_dtypes: Optional[List["np.dtype[Any]"]] = None, name: Optional[str] = None, device: Optional[Any] = None) -> None: """ @@ -913,7 +917,11 @@ def _get_fft_backend(queue) -> FFTBackend: return FFTBackend.pyvkfft -def get_opencl_fft_app(queue, shape, dtype, inverse): +def get_opencl_fft_app( + queue: "cl.CommandQueue", + shape: Tuple[int, ...], + dtype: "np.dtype[Any]", + inverse: bool) -> Any: """Setup an object for out-of-place FFT on with given shape and dtype on given queue. """ @@ -932,7 +940,12 @@ def get_opencl_fft_app(queue, shape, dtype, inverse): raise RuntimeError(f"Unsupported FFT backend {backend}") -def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None): +def run_opencl_fft( + fft_app: Tuple[Any, FFTBackend], + queue: "cl.CommandQueue", + input_vec: Array, + inverse: bool = False, + wait_for: List["cl.Event"] = None) -> Tuple["cl.Event", Array]: """Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent` that indicate the end and start of the operations carried out and the output vector.