diff --git a/grudge/reductions.py b/grudge/reductions.py index d9b0a6bad37a84bd33f4482dea8d6ba939e5c40f..026f76d3afd384ea9c491fc5256fdc91eb69ae41 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -137,10 +137,6 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> float: :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a scalar denoting the nodal sum. """ - if isinstance(vec, np.ndarray): - return sum(nodal_sum(dcoll, dd, vec[idx]) - for idx in np.ndindex(vec.shape)) - comm = dcoll.mpi_communicator if comm is None: return nodal_sum_loc(dcoll, dd, vec) @@ -159,6 +155,10 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> float: :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a scalar denoting the rank-local nodal sum. """ + if isinstance(vec, np.ndarray): + return sum(nodal_sum_loc(dcoll, dd, vec[idx]) + for idx in np.ndindex(vec.shape)) + actx = vec.array_context return sum([actx.np.sum(grp_ary) for grp_ary in vec]) @@ -171,10 +171,6 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> float: :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a scalar denoting the nodal minimum. """ - if isinstance(vec, np.ndarray): - return min(nodal_min(dcoll, dd, vec[idx]) - for idx in np.ndindex(vec.shape)) - comm = dcoll.mpi_communicator if comm is None: return nodal_min_loc(dcoll, dd, vec) @@ -194,6 +190,10 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> float: :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a scalar denoting the rank-local nodal minimum. """ + if isinstance(vec, np.ndarray): + return min(nodal_min_loc(dcoll, dd, vec[idx]) + for idx in np.ndindex(vec.shape)) + actx = vec.array_context return reduce(lambda acc, grp_ary: actx.np.minimum(acc, actx.np.min(grp_ary)), vec, np.inf) @@ -207,10 +207,6 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> float: :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a scalar denoting the nodal maximum. """ - if isinstance(vec, np.ndarray): - return max(nodal_max(dcoll, dd, vec[idx]) - for idx in np.ndindex(vec.shape)) - comm = dcoll.mpi_communicator if comm is None: return nodal_max_loc(dcoll, dd, vec) @@ -230,6 +226,10 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> float: :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a scalar denoting the rank-local nodal maximum. """ + if isinstance(vec, np.ndarray): + return max(nodal_max_loc(dcoll, dd, vec[idx]) + for idx in np.ndindex(vec.shape)) + actx = vec.array_context return reduce(lambda acc, grp_ary: actx.np.maximum(acc, actx.np.max(grp_ary)), vec, -np.inf)