From f307a04ee4304673edff04e0627ebcd376c05a98 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 27 Nov 2024 13:04:58 -0600
Subject: [PATCH] {to,from}_numpy: Use overloads for more precise type info

---
 arraycontext/context.py             | 18 +++++++++++++++++-
 arraycontext/impl/numpy/__init__.py | 27 ++++++++++++++++++++++-----
 arraycontext/impl/pytato/utils.py   |  5 +----
 3 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/arraycontext/context.py b/arraycontext/context.py
index 0d0595c..5c7651c 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -156,7 +156,7 @@ THE SOFTWARE.
 
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Mapping
-from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload
 from warnings import warn
 
 import numpy as np
@@ -320,6 +320,14 @@ class ArrayContext(ABC):
 
         return self.np.zeros(shape, dtype)
 
+    @overload
+    def from_numpy(self, array: np.ndarray) -> Array:
+        ...
+
+    @overload
+    def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
+        ...
+
     @abstractmethod
     def from_numpy(self,
                    array: NumpyOrContainerOrScalar
@@ -333,6 +341,14 @@ class ArrayContext(ABC):
             intact.
         """
 
+    @overload
+    def to_numpy(self, array: Array) -> np.ndarray:
+        ...
+
+    @overload
+    def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
+        ...
+
     @abstractmethod
     def to_numpy(self,
                  array: ArrayOrContainerOrScalar
diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py
index c2f884a..f9d6c54 100644
--- a/arraycontext/impl/numpy/__init__.py
+++ b/arraycontext/impl/numpy/__init__.py
@@ -1,7 +1,4 @@
-from __future__ import annotations
-
-
-__doc__ = """
+"""
 .. currentmodule:: arraycontext
 
 A :mod:`numpy`-based array context.
@@ -9,6 +6,9 @@ A :mod:`numpy`-based array context.
 .. autoclass:: NumpyArrayContext
 """
 
+from __future__ import annotations
+
+
 __copyright__ = """
 Copyright (C) 2021 University of Illinois Board of Trustees
 """
@@ -33,7 +33,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any
+from typing import Any, overload
 
 import numpy as np
 
@@ -46,6 +46,7 @@ from arraycontext.context import (
     ArrayContext,
     ArrayOrContainerOrScalar,
     ArrayOrContainerOrScalarT,
+    ContainerOrScalarT,
     NumpyOrContainerOrScalar,
     UntransformedCodeWarning,
 )
@@ -84,11 +85,27 @@ class NumpyArrayContext(ArrayContext):
     def clone(self):
         return type(self)()
 
+    @overload
+    def from_numpy(self, array: np.ndarray) -> Array:
+        ...
+
+    @overload
+    def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
+        ...
+
     def from_numpy(self,
                    array: NumpyOrContainerOrScalar
                    ) -> ArrayOrContainerOrScalar:
         return array
 
+    @overload
+    def to_numpy(self, array: Array) -> np.ndarray:
+        ...
+
+    @overload
+    def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
+        ...
+
     def to_numpy(self,
                  array: ArrayOrContainerOrScalar
                  ) -> NumpyOrContainerOrScalar:
diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py
index c031e29..6441527 100644
--- a/arraycontext/impl/pytato/utils.py
+++ b/arraycontext/impl/pytato/utils.py
@@ -163,10 +163,7 @@ class TransferFromNumpyMapper(CopyMapper):
 
         # https://github.com/pylint-dev/pylint/issues/3893
         # pylint: disable=unexpected-keyword-arg
-        # type-ignore: discussed at
-        # https://github.com/inducer/arraycontext/pull/289#discussion_r1855523967
-        # possibly related: https://github.com/python/mypy/issues/17375
-        return DataWrapper(  # type: ignore[call-arg]
+        return DataWrapper(
             data=new_dw.data,
             shape=expr.shape,
             axes=expr.axes,
-- 
GitLab