diff --git a/grudge/op.py b/grudge/op.py
index 18ccd26619644deefd43246b6b0f611b099ff9d3..1607cba6154eca2edaf7cad1bb33fa7427ed2d7f 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -270,8 +270,7 @@ def local_grad(
         f=partial(_strong_scalar_grad, dcoll, dd_in),
         in_shape=(), out_shape=(dcoll.ambient_dim,),
-        ary=vec, is_scalar=lambda v: isinstance(v, DOFArray),
-        return_nested=nested,)
+        ary=vec, scalar_cls=DOFArray, return_nested=nested,)
 
 
 def local_d_dx(
@@ -329,8 +328,7 @@ def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainerT:
             for i, vec_i in enumerate(vec)),
         in_shape=(dcoll.ambient_dim,), out_shape=(),
-        ary=vecs,
-        is_scalar=lambda v: isinstance(v, DOFArray))
+        ary=vecs, scalar_cls=DOFArray)
 
 
 # }}}
@@ -434,8 +432,7 @@ def weak_local_grad(
         f=partial(_weak_scalar_grad, dcoll, dd_in),
         in_shape=(), out_shape=(dcoll.ambient_dim,),
-        ary=vecs, is_scalar=lambda v: isinstance(v, DOFArray),
-        return_nested=nested)
+        ary=vecs, scalar_cls=DOFArray, return_nested=nested)
 
 
 def weak_local_d_dx(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT:
@@ -542,7 +539,7 @@ def weak_local_div(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT:
             for i, vec_i in enumerate(vec)),
         in_shape=(dcoll.ambient_dim,), out_shape=(),
-        ary=vecs, is_scalar=lambda v: isinstance(v, DOFArray))
+        ary=vecs, scalar_cls=DOFArray)
 
 
 # }}}
diff --git a/grudge/tools.py b/grudge/tools.py
index 74889b8803aa15cf787301225fdfd9e7da982609..246c21d8167f1efca0b017d343734ad5757b320a 100644
--- a/grudge/tools.py
+++ b/grudge/tools.py
@@ -320,19 +320,28 @@ def rec_map_subarrays(
         in_shape: Tuple[int, ...],
         out_shape: Tuple[int, ...],
         ary: ArrayOrContainerT, *,
-        is_scalar: Optional[Callable[[Any], bool]] = None,
+        scalar_cls: Optional[Union[type, Tuple[type]]] = None,
         return_nested: bool = False) -> ArrayOrContainerT:
     r"""
     Like :func:`map_subarrays`, but with support for
     :class:`arraycontext.ArrayContainer`\ s.
 
-    :param is_scalar: a function that indicates whether a given object is to be
-        treated as a scalar or not.
+    :param scalar_cls: An array container of this type will be considered a scalar
+        and arrays of it will be passed to *f* without further destructuring.
     """
+    if scalar_cls is not None:
+        def is_scalar(x):
+            return np.isscalar(x) or isinstance(x, scalar_cls)
+    else:
+        def is_scalar(x):
+            return np.isscalar(x)
+
     def is_array_of_scalars(a):
         return (
-            isinstance(a, np.ndarray) and a.dtype == object
-            and all(is_scalar(a[idx]) for idx in np.ndindex(a.shape)))
+            isinstance(a, np.ndarray)
+            and (
+                a.dtype != object
+                or all(is_scalar(a[idx]) for idx in np.ndindex(a.shape))))
 
     if is_scalar(ary) or is_array_of_scalars(ary):
         return map_subarrays(
@@ -341,7 +350,7 @@ def rec_map_subarrays(
     from arraycontext import map_array_container
     return map_array_container(
         partial(
-            rec_map_subarrays, f, in_shape, out_shape, is_scalar=is_scalar,
+            rec_map_subarrays, f, in_shape, out_shape, scalar_cls=scalar_cls,
             return_nested=return_nested), ary)
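
Reviewer note, not part of the patch: below is a minimal usage sketch of the revised rec_map_subarrays keyword, where the is_scalar predicate is replaced by scalar_cls. It assumes the grudge tree at this revision (rec_map_subarrays still living in grudge.tools); the Vec class is a hypothetical stand-in for a DOFArray-like leaf type and is not part of grudge.

import numpy as np

from grudge.tools import rec_map_subarrays


class Vec:
    """Hypothetical DOFArray-like leaf; scalar_cls=Vec keeps it intact."""
    def __init__(self, data):
        self.data = np.asarray(data)


# An object array whose two entries are Vec "scalars" along an axis of length 2.
ary = np.empty(2, dtype=object)
ary[0] = Vec([1.0, 2.0])
ary[1] = Vec([3.0, 4.0])


def sum_components(subary):
    # Receives the whole length-2 sub-array; each entry stays a Vec because
    # scalar_cls=Vec stops rec_map_subarrays from destructuring it further.
    return Vec(subary[0].data + subary[1].data)


result = rec_map_subarrays(
    sum_components, in_shape=(2,), out_shape=(), ary=ary, scalar_cls=Vec)
print(result.data)  # expected: [4. 6.]

A side effect visible in the grudge/op.py hunks above: passing a class (or tuple of classes) lets the call sites drop their per-call "lambda v: isinstance(v, DOFArray)" predicates.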