From 214e865b3c61a020034dc2e63014dc5bd4e33c71 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Wed, 20 Oct 2021 12:53:32 -0500
Subject: [PATCH] attempt to fix force_device_scalars handling

---
 grudge/reductions.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/grudge/reductions.py b/grudge/reductions.py
index c5a6fbea..8049358b 100644
--- a/grudge/reductions.py
+++ b/grudge/reductions.py
@@ -165,18 +165,7 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
 
     actx = vec.array_context
 
-    if hasattr(actx, "_force_device_scalars"):
-        force_device_scalars = actx._force_device_scalars
-    else:
-        force_device_scalars = True
-
-    def device_sum(x):
-        result = actx.np.sum(x)
-        if not force_device_scalars:
-            result = actx.from_numpy(np.array(result))
-        return result
-
-    return sum([device_sum(grp_ary) for grp_ary in vec])
+    return sum([actx.np.sum(grp_ary) for grp_ary in vec])
 
 
 def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> Any:
@@ -228,9 +217,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
             result = as_device_scalar(result)
         return result
 
-    return reduce(
+    result = reduce(
             lambda acc, grp_ary: actx.np.minimum(acc, device_min(grp_ary)),
             vec, as_device_scalar(np.inf))
+    if not force_device_scalars:
+        result = actx.to_numpy(result)[()]
+
+    return result
 
 
 def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> Any:
@@ -282,9 +275,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
             result = as_device_scalar(result)
         return result
 
-    return reduce(
+    result = reduce(
             lambda acc, grp_ary: actx.np.maximum(acc, device_max(grp_ary)),
             vec, as_device_scalar(-np.inf))
+    if not force_device_scalars:
+        result = actx.to_numpy(result)[()]
+
+    return result
 
 
 def integral(dcoll: DiscretizationCollection, dd, vec):
-- 
GitLab