diff --git a/grudge/reductions.py b/grudge/reductions.py index c5a6fbea647c6228ecb72a7f6ff0b341b52df112..8049358b784d21bfb6eb5fb24a1cf3ce0a46fb00 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -165,18 +165,7 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Any: actx = vec.array_context - if hasattr(actx, "_force_device_scalars"): - force_device_scalars = actx._force_device_scalars - else: - force_device_scalars = True - - def device_sum(x): - result = actx.np.sum(x) - if not force_device_scalars: - result = actx.from_numpy(np.array(result)) - return result - - return sum([device_sum(grp_ary) for grp_ary in vec]) + return sum([actx.np.sum(grp_ary) for grp_ary in vec]) def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> Any: @@ -228,9 +217,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> Any: result = as_device_scalar(result) return result - return reduce( + result = reduce( lambda acc, grp_ary: actx.np.minimum(acc, device_min(grp_ary)), vec, as_device_scalar(np.inf)) + if not force_device_scalars: + result = actx.to_numpy(result)[()] + + return result def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> Any: @@ -282,9 +275,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> Any: result = as_device_scalar(result) return result - return reduce( + result = reduce( lambda acc, grp_ary: actx.np.maximum(acc, device_max(grp_ary)), vec, as_device_scalar(-np.inf)) + if not force_device_scalars: + result = actx.to_numpy(result)[()] + + return result def integral(dcoll: DiscretizationCollection, dd, vec):