From 6b096bea8dc3fb542d27998c9fe25de131ac53be Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 6 Aug 2024 12:49:52 -0500 Subject: [PATCH] outer: disallow non-object numpy arrays --- arraycontext/container/traversal.py | 16 +++++++++++----- test/test_arraycontext.py | 9 --------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 100f077..a7547df 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -949,8 +949,7 @@ def outer(a: Any, b: Any) -> Any: Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer` always returns a matrix). Here the definition of "scalar" includes - all non-array-container types and any scalar-like array container types - (including non-object numpy arrays). + all non-array-container types and any scalar-like array container types. If *a* and *b* are both array containers, the result will have the same type as *a*. If both are array containers and neither is an object array, they must @@ -968,12 +967,19 @@ def outer(a: Any, b: Any) -> Any: # This condition is whether "ndarrays should broadcast inside x". and NumpyObjectArray not in x.__class__._outer_bcast_types) + a_is_ndarray = isinstance(a, np.ndarray) + b_is_ndarray = isinstance(b, np.ndarray) + + if a_is_ndarray and a.dtype != object: + raise TypeError("passing a non-object numpy array is not allowed") + if b_is_ndarray and b.dtype != object: + raise TypeError("passing a non-object numpy array is not allowed") + if treat_as_scalar(a) or treat_as_scalar(b): return a*b - # After this point, "isinstance(o, ndarray)" means o is an object array. - elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + elif a_is_ndarray and b_is_ndarray: return np.outer(a, b) - elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + elif a_is_ndarray or b_is_ndarray: return map_array_container(lambda x: outer(x, b), a) else: if type(a) is not type(b): diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ae65a3d..63387db 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1457,15 +1457,6 @@ def test_outer(actx_factory): b_bcast_dc_of_dofs.momentum), enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy)) - # Non-object numpy arrays should be treated as scalars - ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass)) - assert equal( - outer(ary_of_floats, b_bcast_dc_of_dofs), - ary_of_floats*b_bcast_dc_of_dofs) - assert equal( - outer(a_bcast_dc_of_dofs, ary_of_floats), - a_bcast_dc_of_dofs*ary_of_floats) - # }}} -- GitLab