diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 89c4e885b34546d81c15f6d7de426241de27ebce..28910150feb4d0124cd633a00731cb2ff39abf37 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,16 +30,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Dict, Sequence, Union +from typing import Any, Dict, Sequence, Union import numpy as np import loopy as lp from pytools.tag import Tag +from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ArrayContext -from arraycontext.container.traversal import ( - rec_map_array_container, with_array_context) + + +class NumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype != object + + +class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass): + pass class NumpyArrayContext(ArrayContext): @@ -53,7 +61,7 @@ class NumpyArrayContext(ArrayContext): self._loopy_transform_cache: \ Dict[lp.TranslationUnit, lp.TranslationUnit] = {} - self.array_types = (np.ndarray,) + array_types = (NumpyNonObjectArray,) def _get_fake_numpy_namespace(self): from .fake_numpy import NumpyFakeNumpyNamespace