From 0ee426439892caac10547f2f8f1573f59892af00 Mon Sep 17 00:00:00 2001
From: Thomas Gibson <gibsonthomas1120@hotmail.com>
Date: Wed, 16 Jun 2021 19:11:00 -0500
Subject: [PATCH] Move array checking to local nodal reductions

---
 grudge/reductions.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/grudge/reductions.py b/grudge/reductions.py
index d9b0a6ba..026f76d3 100644
--- a/grudge/reductions.py
+++ b/grudge/reductions.py
@@ -137,10 +137,6 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> float:
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`.
     :returns: a scalar denoting the nodal sum.
     """
-    if isinstance(vec, np.ndarray):
-        return sum(nodal_sum(dcoll, dd, vec[idx])
-                   for idx in np.ndindex(vec.shape))
-
     comm = dcoll.mpi_communicator
     if comm is None:
         return nodal_sum_loc(dcoll, dd, vec)
@@ -159,6 +155,10 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> float:
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`.
     :returns: a scalar denoting the rank-local nodal sum.
     """
+    if isinstance(vec, np.ndarray):
+        return sum(nodal_sum_loc(dcoll, dd, vec[idx])
+                   for idx in np.ndindex(vec.shape))
+
     actx = vec.array_context
     return sum([actx.np.sum(grp_ary) for grp_ary in vec])
 
@@ -171,10 +171,6 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> float:
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`.
     :returns: a scalar denoting the nodal minimum.
     """
-    if isinstance(vec, np.ndarray):
-        return min(nodal_min(dcoll, dd, vec[idx])
-                   for idx in np.ndindex(vec.shape))
-
     comm = dcoll.mpi_communicator
     if comm is None:
         return nodal_min_loc(dcoll, dd, vec)
@@ -194,6 +190,10 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> float:
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`.
     :returns: a scalar denoting the rank-local nodal minimum.
     """
+    if isinstance(vec, np.ndarray):
+        return min(nodal_min_loc(dcoll, dd, vec[idx])
+                   for idx in np.ndindex(vec.shape))
+
     actx = vec.array_context
     return reduce(lambda acc, grp_ary: actx.np.minimum(acc, actx.np.min(grp_ary)),
                   vec, np.inf)
@@ -207,10 +207,6 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> float:
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`.
     :returns: a scalar denoting the nodal maximum.
     """
-    if isinstance(vec, np.ndarray):
-        return max(nodal_max(dcoll, dd, vec[idx])
-                   for idx in np.ndindex(vec.shape))
-
     comm = dcoll.mpi_communicator
     if comm is None:
         return nodal_max_loc(dcoll, dd, vec)
@@ -230,6 +226,10 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> float:
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`.
     :returns: a scalar denoting the rank-local nodal maximum.
     """
+    if isinstance(vec, np.ndarray):
+        return max(nodal_max_loc(dcoll, dd, vec[idx])
+                   for idx in np.ndindex(vec.shape))
+
     actx = vec.array_context
     return reduce(lambda acc, grp_ary: actx.np.maximum(acc, actx.np.max(grp_ary)),
                   vec, -np.inf)
-- 
GitLab