From 214e865b3c61a020034dc2e63014dc5bd4e33c71 Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Wed, 20 Oct 2021 12:53:32 -0500 Subject: [PATCH] attempt to fix force_device_scalars handling --- grudge/reductions.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/grudge/reductions.py b/grudge/reductions.py index c5a6fbea..8049358b 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -165,18 +165,7 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Any: actx = vec.array_context - if hasattr(actx, "_force_device_scalars"): - force_device_scalars = actx._force_device_scalars - else: - force_device_scalars = True - - def device_sum(x): - result = actx.np.sum(x) - if not force_device_scalars: - result = actx.from_numpy(np.array(result)) - return result - - return sum([device_sum(grp_ary) for grp_ary in vec]) + return sum([actx.np.sum(grp_ary) for grp_ary in vec]) def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> Any: @@ -228,9 +217,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> Any: result = as_device_scalar(result) return result - return reduce( + result = reduce( lambda acc, grp_ary: actx.np.minimum(acc, device_min(grp_ary)), vec, as_device_scalar(np.inf)) + if not force_device_scalars: + result = actx.to_numpy(result)[()] + + return result def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> Any: @@ -282,9 +275,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> Any: result = as_device_scalar(result) return result - return reduce( + result = reduce( lambda acc, grp_ary: actx.np.maximum(acc, device_max(grp_ary)), vec, as_device_scalar(-np.inf)) + if not force_device_scalars: + result = actx.to_numpy(result)[()] + + return result def integral(dcoll: DiscretizationCollection, dd, vec): -- GitLab