From afb4b96a013036f6d74c1ead353ed83bfdbe9cb4 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 16:21:23 -0500 Subject: [PATCH] Numpy actx: Narrow array_types to non-obj arrays --- arraycontext/impl/numpy/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 89c4e88..2891015 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,16 +30,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Dict, Sequence, Union +from typing import Any, Dict, Sequence, Union import numpy as np import loopy as lp from pytools.tag import Tag +from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ArrayContext -from arraycontext.container.traversal import ( - rec_map_array_container, with_array_context) + + +class NumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype != object + + +class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass): + pass class NumpyArrayContext(ArrayContext): @@ -53,7 +61,7 @@ class NumpyArrayContext(ArrayContext): self._loopy_transform_cache: \ Dict[lp.TranslationUnit, lp.TranslationUnit] = {} - self.array_types = (np.ndarray,) + array_types = (NumpyNonObjectArray,) def _get_fake_numpy_namespace(self): from .fake_numpy import NumpyFakeNumpyNamespace -- GitLab