diff --git a/grudge/reductions.py b/grudge/reductions.py index 83192e42c85bfb5aa041eb8c26d27cd78dd12f5b..2d23f7066d56c1742f9a670c9ac92d7b9822222c 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -161,7 +161,19 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> float: for idx in np.ndindex(vec.shape)) actx = vec.array_context - return sum([actx.np.sum(grp_ary) for grp_ary in vec]) + + if hasattr(actx, "_force_device_scalars"): + force_device_scalars = actx._force_device_scalars + else: + force_device_scalars = True + + def device_sum(x): + result = actx.np.sum(x) + if not force_device_scalars: + result = actx.from_numpy(np.array(result)) + return result + + return sum([device_sum(grp_ary) for grp_ary in vec]) def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> float: @@ -198,10 +210,24 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> float: for idx in np.ndindex(vec.shape)) actx = vec.array_context + + if hasattr(actx, "_force_device_scalars"): + force_device_scalars = actx._force_device_scalars + else: + force_device_scalars = True + + def as_device_scalar(x): + return actx.from_numpy(np.array(x)) + + def device_min(x): + result = actx.np.min(x) + if not force_device_scalars: + result = as_device_scalar(result) + return result + return reduce( - lambda acc, grp_ary: actx.np.minimum( - acc, actx.np.min(grp_ary)), - vec, actx.from_numpy(np.array(np.inf))) + lambda acc, grp_ary: actx.np.minimum(acc, device_min(grp_ary)), + vec, as_device_scalar(np.inf)) def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> float: @@ -239,10 +265,23 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> float: actx = vec.array_context + if hasattr(actx, "_force_device_scalars"): + force_device_scalars = actx._force_device_scalars + else: + force_device_scalars = True + + def as_device_scalar(x): + return actx.from_numpy(np.array(x)) + + def device_max(x): + result = actx.np.max(x) + if not force_device_scalars: + result = as_device_scalar(result) + return result + return reduce( - lambda acc, grp_ary: actx.np.maximum( - acc, actx.np.max(grp_ary)), - vec, actx.from_numpy(np.array(-np.inf))) + lambda acc, grp_ary: actx.np.maximum(acc, device_max(grp_ary)), + vec, as_device_scalar(-np.inf)) def integral(dcoll: DiscretizationCollection, dd, vec) -> float: