diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 8c09226fde8df846c6cfefb1af5ef8a47971e772..d6d7bac92094936f5086d4d05e1a3aa3def0c46c 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -59,7 +59,7 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", "sqrt", "exp"] if name in pt_funcs: - import pytato as pt + import pytato as pt # type: ignore from functools import partial return partial(rec_map_array_container, getattr(pt, name)) @@ -264,7 +264,7 @@ class PytatoArrayContext(ArrayContext): return cl_array.get(queue=self.queue) def call_loopy(self, program, **kwargs): - from pytato.loopy import call_loopy + from pytato.loopy import call_loopy # type: ignore import pyopencl.array as cla entrypoint, = set(program.callables_table) @@ -305,7 +305,8 @@ class PytatoArrayContext(ArrayContext): # }}} def compile(self, f: Callable[[Any], Any], - inputs_like: Tuple[Union[Number, np.ndarray], ...]) -> Callable[..., Any]: + inputs_like: Tuple[Union[Number, np.ndarray], ...]) -> Callable[ + ..., Any]: from pytools.obj_array import flat_obj_array from arraycontext.impl import _is_meshmode_dofarray from meshmode.dof_array import DOFArray