diff --git a/grudge/op.py b/grudge/op.py index 2555156445b8b43b021009c3a334245a3a51739e..45595977c1db08600c14eaadcf8c4afbd949fe18 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -210,19 +210,19 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # {{{ common derivative "helpers" -def _div_helper(ambient_dim, component_div, scalar_type, vecs): +def _div_helper(ambient_dim, component_div, is_scalar, vecs): if not isinstance(vecs, np.ndarray): # vecs is not an object array -> treat as array container return map_array_container( - partial(_div_helper, ambient_dim, component_div, scalar_type), vecs) + partial(_div_helper, ambient_dim, component_div, is_scalar), vecs) assert vecs.dtype == object - if vecs.size and not isinstance(vecs[(0,)*vecs.ndim], scalar_type): + if vecs.size and not is_scalar(vecs[(0,)*vecs.ndim]): # vecs is an object array containing further object arrays # -> treat as array container return map_array_container( - partial(_div_helper, ambient_dim, component_div, scalar_type), vecs) + partial(_div_helper, ambient_dim, component_div, is_scalar), vecs) if vecs.shape[-1] != ambient_dim: raise ValueError("last/innermost dimension of *vecs* argument doesn't match " @@ -239,7 +239,7 @@ def _div_helper(ambient_dim, component_div, scalar_type, vecs): return result -def _grad_helper(ambient_dim, component_grad, scalar_type, vecs, nested): +def _grad_helper(ambient_dim, component_grad, is_scalar, vecs, nested): if isinstance(vecs, np.ndarray): # Occasionally, data structures coming from *mirgecom* will # contain empty object arrays as placeholders for fields. @@ -256,16 +256,16 @@ def _grad_helper(ambient_dim, component_grad, scalar_type, vecs, nested): # derivatives by stacking the results. grad = obj_array_vectorize( lambda el: _grad_helper( - ambient_dim, component_grad, scalar_type, el, nested=nested), vecs) + ambient_dim, component_grad, is_scalar, el, nested=nested), vecs) if nested: return grad else: return np.stack(grad, axis=0) - if not isinstance(vecs, scalar_type): + if not is_scalar(vecs): return map_array_container( partial( - _grad_helper, ambient_dim, component_grad, scalar_type, + _grad_helper, ambient_dim, component_grad, is_scalar, nested=nested), vecs) @@ -334,7 +334,7 @@ def local_grad( return _grad_helper( dcoll.ambient_dim, partial(_strong_scalar_grad, dcoll, dd_in), - DOFArray, + lambda v: isinstance(v, DOFArray), vec, nested=nested) @@ -392,7 +392,11 @@ def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainerT: local_d_dx(dcoll, i, vec_i) for i, vec_i in enumerate(vec)) - return _div_helper(dcoll.ambient_dim, component_div, DOFArray, vecs) + return _div_helper( + dcoll.ambient_dim, + component_div, + lambda v: isinstance(v, DOFArray), + vecs) # }}} @@ -494,7 +498,7 @@ def weak_local_grad( return _grad_helper( dcoll.ambient_dim, partial(_weak_scalar_grad, dcoll, dd_in), - DOFArray, + lambda v: isinstance(v, DOFArray), vecs, nested=nested) @@ -601,7 +605,11 @@ def weak_local_div(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: weak_local_d_dx(dcoll, dd_in, i, vec_i) for i, vec_i in enumerate(vec)) - return _div_helper(dcoll.ambient_dim, component_div, DOFArray, vecs) + return _div_helper( + dcoll.ambient_dim, + component_div, + lambda v: isinstance(v, DOFArray), + vecs) # }}}