From 6b096bea8dc3fb542d27998c9fe25de131ac53be Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 6 Aug 2024 12:49:52 -0500
Subject: [PATCH] outer: disallow non-object numpy arrays

---
 arraycontext/container/traversal.py | 16 +++++++++++-----
 test/test_arraycontext.py           |  9 ---------
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 100f077..a7547df 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -949,8 +949,7 @@ def outer(a: Any, b: Any) -> Any:
     Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional
     object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer`
     always returns a matrix). Here the definition of "scalar" includes
-    all non-array-container types and any scalar-like array container types
-    (including non-object numpy arrays).
+    all non-array-container types and any scalar-like array container types.
 
     If *a* and *b* are both array containers, the result will have the same type
     as *a*. If both are array containers and neither is an object array, they must
@@ -968,12 +967,19 @@ def outer(a: Any, b: Any) -> Any:
                 # This condition is whether "ndarrays should broadcast inside x".
                 and NumpyObjectArray not in x.__class__._outer_bcast_types)
 
+    a_is_ndarray = isinstance(a, np.ndarray)
+    b_is_ndarray = isinstance(b, np.ndarray)
+
+    if a_is_ndarray and a.dtype != object:
+        raise TypeError("passing a non-object numpy array is not allowed")
+    if b_is_ndarray and b.dtype != object:
+        raise TypeError("passing a non-object numpy array is not allowed")
+
     if treat_as_scalar(a) or treat_as_scalar(b):
         return a*b
-    # After this point, "isinstance(o, ndarray)" means o is an object array.
-    elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+    elif a_is_ndarray and b_is_ndarray:
         return np.outer(a, b)
-    elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
+    elif a_is_ndarray or b_is_ndarray:
         return map_array_container(lambda x: outer(x, b), a)
     else:
         if type(a) is not type(b):
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index ae65a3d..63387db 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -1457,15 +1457,6 @@ def test_outer(actx_factory):
                 b_bcast_dc_of_dofs.momentum),
             enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy))
 
-    # Non-object numpy arrays should be treated as scalars
-    ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass))
-    assert equal(
-        outer(ary_of_floats, b_bcast_dc_of_dofs),
-        ary_of_floats*b_bcast_dc_of_dofs)
-    assert equal(
-        outer(a_bcast_dc_of_dofs, ary_of_floats),
-        a_bcast_dc_of_dofs*ary_of_floats)
-
 # }}}
 
 
-- 
GitLab