From 2334aa00036c50b942a6bd89381d427e9ad98786 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Wed, 27 Apr 2022 10:18:21 -0500
Subject: [PATCH] Revert "Rework *map_subarrays docs, add types, use leaf_cls"

This reverts commit 47d2d6b0dc76bdc01be7168cc79e13da72540820.
---
 grudge/op.py    | 14 +++++++++----
 grudge/tools.py | 53 +++++++++++++++++++++----------------------------
 2 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index bdb77cde..5b7c10f2 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -270,7 +270,9 @@ def local_grad(
         f=partial(_strong_scalar_grad, dcoll, dd_in),
         in_shape=(),
         out_shape=(dcoll.ambient_dim,),
-        ary=vec, leaf_cls=DOFArray, return_nested=nested,)
+        is_scalar=lambda v: isinstance(v, DOFArray),
+        return_nested=nested,
+        ary=vec)
 
 
 def local_d_dx(
@@ -328,7 +330,8 @@ def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainerT:
             for i, vec_i in enumerate(vec)),
         in_shape=(dcoll.ambient_dim,),
         out_shape=(),
-        ary=vecs, leaf_cls=DOFArray)
+        is_scalar=lambda v: isinstance(v, DOFArray),
+        ary=vecs)
 
 # }}}
 
@@ -432,7 +435,9 @@ def weak_local_grad(
         f=partial(_weak_scalar_grad, dcoll, dd_in),
         in_shape=(),
         out_shape=(dcoll.ambient_dim,),
-        ary=vecs, leaf_cls=DOFArray, return_nested=nested)
+        is_scalar=lambda v: isinstance(v, DOFArray),
+        return_nested=nested,
+        ary=vecs)
 
 
 def weak_local_d_dx(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT:
@@ -539,7 +544,8 @@ def weak_local_div(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT:
             for i, vec_i in enumerate(vec)),
         in_shape=(dcoll.ambient_dim,),
         out_shape=(),
-        ary=vecs, leaf_cls=DOFArray)
+        is_scalar=lambda v: isinstance(v, DOFArray),
+        ary=vecs)
 
 # }}}
 
diff --git a/grudge/tools.py b/grudge/tools.py
index ab7f5e02..72aaae8d 100644
--- a/grudge/tools.py
+++ b/grudge/tools.py
@@ -28,7 +28,6 @@ THE SOFTWARE.
 
 import numpy as np
 from pytools import levi_civita
-from typing import Tuple, Callable, Optional
 from functools import partial
 
 
@@ -239,20 +238,12 @@ def build_jacobian(actx, f, base_state, stepsize):
     return mat
 
 
-def map_subarrays(
-        f: Callable[[np.ndarray], np.ndarray],
-        in_shape: Tuple[int, ...], out_shape: Tuple[int, ...],
-        ary: np.ndarray, *, return_nested=False) -> np.ndarray:
+def map_subarrays(f, in_shape, out_shape, ary, *, return_nested=False):
     """
-    Apply a function *f* to subarrrays of shape *in_shape* of an
-    :class:`numpy.ndarray`, typically (but not necessarily) of dtype
-    :class:`object`. Return an :class:`numpy.ndarray` of the same dtype,
-    with the corresponding subarrays replaced by the return values of *f*,
-    and with the shape adapted to reflect *out_shape*.
+    Map a function *f* over a :class:`numpy.ndarray`, applying it to subarrays
+    of shape *in_shape* individually.
 
-    Similar to :class:`numpy.vectorize`.
-
-    *Example 1:* given a function *f* that maps arrays of shape ``(2, 2)`` to scalars
+    Example 1: given a function *f* that maps arrays of shape ``(2, 2)`` to scalars
     and an input array *ary* of shape ``(3, 2, 2)``, the call::
 
         map_subarrays(f, (2, 2), (), ary)
@@ -260,7 +251,7 @@ def map_subarrays(
     will produce an array of shape ``(3,)`` containing the results of calling *f* on
     the 3 subarrays of shape ``(2, 2)`` in *ary*.
 
-    *Example 2:* given a function *f* that maps arrays of shape ``(2,)`` to arrays of
+    Example 2: given a function *f* that maps arrays of shape ``(2,)`` to arrays of
     shape ``(2, 2)`` and an input array *ary* of shape ``(3, 2)``, the call::
 
         map_subarrays(f, (2,), (2, 2), ary)
@@ -303,39 +294,41 @@ def map_subarrays(
             in_slice = tuple(slice(0, n) for n in in_shape)
             out_slice = tuple(slice(0, n) for n in out_shape)
             if return_nested:
-                result = np.empty(base_shape, dtype=ary.dtype)
+                result = np.empty(base_shape, dtype=object)
                 for idx in np.ndindex(base_shape):
                     result[idx] = f(ary[idx + in_slice])
                 return result
             else:
-                result = np.empty(base_shape + out_shape, dtype=ary.dtype)
+                result = np.empty(base_shape + out_shape, dtype=object)
                 for idx in np.ndindex(base_shape):
                     result[idx + out_slice] = f(ary[idx + in_slice])
                 return result
 
 
 def rec_map_subarrays(
-        f: Callable[[np.ndarray], np.ndarray],
-        in_shape: Tuple[int, ...],
-        out_shape: Tuple[int, ...],
-        ary: np.ndarray, *,
-        leaf_cls: Optional[type] = None,
-        return_nested: bool = False) -> np.ndarray:
-    r"""
-    Like :func:`map_subarrays`, but with support for
-    :class:`arraycontext.ArrayContainer`\ s.
-
-    :param leaf_cls: An array container of this type will be considered a leaf
-        and will be passed to *f* without further destructuring.
+        f, in_shape, out_shape, is_scalar, ary, *, return_nested=False):
     """
-    if type(ary) is leaf_cls or isinstance(ary, np.ndarray):
+    Map a function *f* over an object *ary*, applying it to any subarrays
+    of shape *in_shape* individually.
+
+    Array container version of :func:`map_subarrays`.
+
+    :param is_scalar: a function that indicates whether a given object is to be
+        treated as a scalar or not.
+    """
+    def is_array_of_scalars(a):
+        return (
+            isinstance(a, np.ndarray) and a.dtype == object
+            and all(is_scalar(a[idx]) for idx in np.ndindex(a.shape)))
+
+    if is_scalar(ary) or is_array_of_scalars(ary):
         return map_subarrays(
             f, in_shape, out_shape, ary, return_nested=return_nested)
     else:
         from arraycontext import map_array_container
         return map_array_container(
             partial(
-                rec_map_subarrays, f, in_shape, out_shape, leaf_cls=leaf_cls,
+                rec_map_subarrays, f, in_shape, out_shape, is_scalar,
                 return_nested=return_nested),
             ary)
 
-- 
GitLab