From 683d11d1eb02e33ee6963caa7d10222ea7e22a77 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Wed, 27 Apr 2022 13:15:13 -0500
Subject: [PATCH] assume numpy scalar types are scalars and use scalar_cls
 instead of is_scalar

---
 grudge/op.py    | 11 ++++-------
 grudge/tools.py | 21 +++++++++++++++------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index 18ccd266..1607cba6 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -270,8 +270,7 @@ def local_grad(
         f=partial(_strong_scalar_grad, dcoll, dd_in),
         in_shape=(),
         out_shape=(dcoll.ambient_dim,),
-        ary=vec, is_scalar=lambda v: isinstance(v, DOFArray),
-        return_nested=nested,)
+        ary=vec, scalar_cls=DOFArray, return_nested=nested,)
 
 
 def local_d_dx(
@@ -329,8 +328,7 @@ def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainerT:
             for i, vec_i in enumerate(vec)),
         in_shape=(dcoll.ambient_dim,),
         out_shape=(),
-        ary=vecs,
-        is_scalar=lambda v: isinstance(v, DOFArray))
+        ary=vecs, scalar_cls=DOFArray)
 
 # }}}
 
@@ -434,8 +432,7 @@ def weak_local_grad(
         f=partial(_weak_scalar_grad, dcoll, dd_in),
         in_shape=(),
         out_shape=(dcoll.ambient_dim,),
-        ary=vecs, is_scalar=lambda v: isinstance(v, DOFArray),
-        return_nested=nested)
+        ary=vecs, scalar_cls=DOFArray, return_nested=nested)
 
 
 def weak_local_d_dx(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT:
@@ -542,7 +539,7 @@ def weak_local_div(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT:
             for i, vec_i in enumerate(vec)),
         in_shape=(dcoll.ambient_dim,),
         out_shape=(),
-        ary=vecs, is_scalar=lambda v: isinstance(v, DOFArray))
+        ary=vecs, scalar_cls=DOFArray)
 
 # }}}
 
diff --git a/grudge/tools.py b/grudge/tools.py
index 74889b88..246c21d8 100644
--- a/grudge/tools.py
+++ b/grudge/tools.py
@@ -320,19 +320,28 @@ def rec_map_subarrays(
         in_shape: Tuple[int, ...],
         out_shape: Tuple[int, ...],
         ary: ArrayOrContainerT, *,
-        is_scalar: Optional[Callable[[Any], bool]] = None,
+        scalar_cls: Optional[Union[type, Tuple[type]]] = None,
         return_nested: bool = False) -> ArrayOrContainerT:
     r"""
     Like :func:`map_subarrays`, but with support for
     :class:`arraycontext.ArrayContainer`\ s.
 
-    :param is_scalar: a function that indicates whether a given object is to be
-        treated as a scalar or not.
+    :param scalar_cls: An array container of this type will be considered a scalar
+        and arrays of it will be passed to *f* without further destructuring.
     """
+    if scalar_cls is not None:
+        def is_scalar(x):
+            return np.isscalar(x) or isinstance(x, scalar_cls)
+    else:
+        def is_scalar(x):
+            return np.isscalar(x)
+
     def is_array_of_scalars(a):
         return (
-            isinstance(a, np.ndarray) and a.dtype == object
-            and all(is_scalar(a[idx]) for idx in np.ndindex(a.shape)))
+            isinstance(a, np.ndarray)
+            and (
+                a.dtype != object
+                or all(is_scalar(a[idx]) for idx in np.ndindex(a.shape))))
 
     if is_scalar(ary) or is_array_of_scalars(ary):
         return map_subarrays(
@@ -341,7 +350,7 @@ def rec_map_subarrays(
         from arraycontext import map_array_container
         return map_array_container(
             partial(
-                rec_map_subarrays, f, in_shape, out_shape, is_scalar=is_scalar,
+                rec_map_subarrays, f, in_shape, out_shape, scalar_cls=scalar_cls,
                 return_nested=return_nested),
             ary)
 
-- 
GitLab