From e8fd3587c224e25d8fa4de74469001c009686696 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Thu, 28 Oct 2021 10:01:26 -0500
Subject: [PATCH] use arraycontext.DeviceScalar type annotation

---
 doc/conf.py          |  5 +++++
 grudge/dt_utils.py   |  7 +++----
 grudge/reductions.py | 21 ++++++++++-----------
 3 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/doc/conf.py b/doc/conf.py
index 0cd5ba6d..31bc96d9 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -26,6 +26,11 @@ version = get_version()
 # The full version, including alpha/beta/rc tags.
 release = version
 
+autodoc_type_aliases = {
+    "DeviceScalar": "arraycontext.DeviceScalar",
+    "DeviceArray": "arraycontext.DeviceArray",
+    }
+
 intersphinx_mapping = {
     "https://docs.python.org/3/": None,
     "https://numpy.org/doc/stable/": None,
diff --git a/grudge/dt_utils.py b/grudge/dt_utils.py
index a97bdb39..1cc17281 100644
--- a/grudge/dt_utils.py
+++ b/grudge/dt_utils.py
@@ -44,9 +44,8 @@ THE SOFTWARE.
 
 
 import numpy as np
-from typing import Any
 
-from arraycontext import ArrayContext, thaw, freeze
+from arraycontext import ArrayContext, thaw, freeze, DeviceScalar
 from meshmode.transform_metadata import FirstAxisIsElementsTag
 
 from grudge.dof_desc import DD_VOLUME, DOFDesc, as_dofdesc
@@ -159,7 +158,7 @@ def dt_non_geometric_factors(
 
 @memoize_on_first_arg
 def h_max_from_volume(
-        dcoll: DiscretizationCollection, dim=None, dd=None) -> Any:
+        dcoll: DiscretizationCollection, dim=None, dd=None) -> "DeviceScalar":
     """Returns a (maximum) characteristic length based on the volume of the
     elements. This length may not be representative if the elements have very
     high aspect ratios.
@@ -190,7 +189,7 @@ def h_max_from_volume(
 
 @memoize_on_first_arg
 def h_min_from_volume(
-        dcoll: DiscretizationCollection, dim=None, dd=None) -> Any:
+        dcoll: DiscretizationCollection, dim=None, dd=None) -> "DeviceScalar":
     """Returns a (minimum) characteristic length based on the volume of the
     elements. This length may not be representative if the elements have very
     high aspect ratios.
diff --git a/grudge/reductions.py b/grudge/reductions.py
index a1848beb..ec301576 100644
--- a/grudge/reductions.py
+++ b/grudge/reductions.py
@@ -59,9 +59,8 @@ THE SOFTWARE.
 
 from numbers import Number
 from functools import reduce
-from typing import Any
 
-from arraycontext import make_loopy_program
+from arraycontext import make_loopy_program, DeviceScalar
 
 from grudge.discretization import DiscretizationCollection
 
@@ -76,7 +75,7 @@ import grudge.dof_desc as dof_desc
 
 # {{{ Nodal reductions
 
-def _norm(dcoll: DiscretizationCollection, vec, p, dd) -> Any:
+def _norm(dcoll: DiscretizationCollection, vec, p, dd) -> "DeviceScalar":
     if isinstance(vec, Number):
         return np.fabs(vec)
     if p == 2:
@@ -95,7 +94,7 @@ def _norm(dcoll: DiscretizationCollection, vec, p, dd) -> Any:
         raise NotImplementedError("Unsupported value of p")
 
 
-def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> Any:
+def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> "DeviceScalar":
     r"""Return the vector p-norm of a function represented
     by its vector of degrees of freedom *vec*.
 
@@ -131,7 +130,7 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> Any:
     return _norm(dcoll, vec, p, dd)
 
 
-def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     r"""Return the nodal sum of a vector of degrees of freedom *vec*.
 
     :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
@@ -151,7 +150,7 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Any:
         comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM))
 
 
-def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     r"""Return the rank-local nodal sum of a vector of degrees of freedom *vec*.
 
     :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
@@ -168,7 +167,7 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
     return sum([actx.np.sum(grp_ary) for grp_ary in vec])
 
 
-def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     r"""Return the nodal minimum of a vector of degrees of freedom *vec*.
 
     :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
@@ -188,7 +187,7 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> Any:
         comm.allreduce(actx.to_numpy(nodal_min_loc(dcoll, dd, vec)), op=MPI.MIN))
 
 
-def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     r"""Return the rank-local nodal minimum of a vector of degrees
     of freedom *vec*.
 
@@ -208,7 +207,7 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
             vec, actx.from_numpy(np.array(np.inf)))
 
 
-def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     r"""Return the nodal maximum of a vector of degrees of freedom *vec*.
 
     :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
@@ -228,7 +227,7 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> Any:
         comm.allreduce(actx.to_numpy(nodal_max_loc(dcoll, dd, vec)), op=MPI.MAX))
 
 
-def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     r"""Return the rank-local nodal maximum of a vector of degrees
     of freedom *vec*.
 
@@ -248,7 +247,7 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> Any:
             vec, actx.from_numpy(np.array(-np.inf)))
 
 
-def integral(dcoll: DiscretizationCollection, dd, vec) -> Any:
+def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
     """Numerically integrates a function represented by a
     :class:`~meshmode.dof_array.DOFArray` of degrees of freedom.
 
-- 
GitLab