From f15e8da294ee5fc6e3aee1eda2da189021e41ebb Mon Sep 17 00:00:00 2001
From: Thomas Gibson <gibsonthomas1120@hotmail.com>
Date: Sat, 15 May 2021 12:13:20 -0500
Subject: [PATCH] Clarify/add docs, clean up functions

---
 grudge/op.py | 160 ++++++++++++++++++++++++++++-----------------------
 1 file changed, 87 insertions(+), 73 deletions(-)

diff --git a/grudge/op.py b/grudge/op.py
index 13c64450..58aef33b 100644
--- a/grudge/op.py
+++ b/grudge/op.py
@@ -1,4 +1,6 @@
 """
+.. autoclass:: HasElementwiseMatvecTag
+
 .. autofunction:: project
 
 .. autofunction:: nodes
@@ -57,36 +59,44 @@ THE SOFTWARE.
 
 
 from numbers import Number
+
 from pytools import (
     memoize_in,
     memoize_on_first_arg,
     keyed_memoize_in
 )
-from pytools.obj_array import obj_array_vectorize
+from pytools.obj_array import obj_array_vectorize, make_obj_array
+
 from meshmode.array_context import (
     FirstAxisIsElementsTag, make_loopy_program
 )
-from meshmode.dof_array import DOFArray
+from meshmode.dof_array import DOFArray, freeze, flatten, unflatten
+from meshmode.mesh import BTAG_ALL, BTAG_NONE, BTAG_PARTITION  # noqa
 
 import numpy as np
 import grudge.dof_desc as dof_desc
 
-from meshmode.mesh import BTAG_ALL, BTAG_NONE, BTAG_PARTITION  # noqa
-from meshmode.dof_array import freeze, flatten, unflatten
-
 from grudge.symbolic.primitives import TracePair
 
 
 # {{{ Kernel tags
 
 class HasElementwiseMatvecTag(FirstAxisIsElementsTag):
-    pass
+    """A tag that is applicable to kernel programs indicating that
+    an element-wise matrix product is being performed. This indicates
+    that the first index corresponds to element indices and suggests that
+    the implementation should set element indices as the outermost
+    loop extent.
+    """
 
 # }}}
 
 
 # {{{ Interpolation and projection
 
+# FIXME: Should reintroduce interp and make clear distinctions
+# between projection and interpolations.
+# Related issue: https://github.com/inducer/grudge/issues/38
 # def interp(dcoll, src, tgt, vec):
 #     from warnings import warn
 #     warn("using 'interp' is deprecated, use 'project' instead.",
@@ -133,9 +143,10 @@ def nodes(dcoll, dd=None):
     :returns: an object array of :class:`~meshmode.dof_array.DOFArray`\ s
     """
     if dd is None:
-        return dcoll._volume_discr.nodes()
-    else:
-        return dcoll.discr_from_dd(dd).nodes()
+        dd = dof_desc.DD_VOLUME
+    dd = dof_desc.as_dofdesc(dd)
+
+    return dcoll.discr_from_dd(dd).nodes()
 
 
 @memoize_on_first_arg
@@ -172,9 +183,9 @@ def h_max_from_volume(dcoll, dim=None, dd=None):
     if dim is None:
         dim = dcoll.dim
 
-    ones_volm = dcoll._volume_discr.zeros(dcoll._setup_actx) + 1.0
-    return nodal_max(
-        dcoll, dd, elementwise_sum(dcoll, mass(dcoll, dd, ones_volm))
+    ones = dcoll.discr_from_dd(dd).zeros(dcoll._setup_actx) + 1.0
+    return nodal_maximum(
+        elementwise_sum(dcoll, mass(dcoll, dd, ones))
     ) ** (1.0 / dim)
 
 
@@ -198,9 +209,9 @@ def h_min_from_volume(dcoll, dim=None, dd=None):
     if dim is None:
         dim = dcoll.dim
 
-    ones_volm = dcoll._volume_discr.zeros(dcoll._setup_actx) + 1.0
-    return nodal_min(
-        dcoll, dd, elementwise_sum(dcoll, mass(dcoll, dd, ones_volm))
+    ones = dcoll.discr_from_dd(dd).zeros(dcoll._setup_actx) + 1.0
+    return nodal_minimum(
+        elementwise_sum(dcoll, mass(dcoll, dd, ones))
     ) ** (1.0 / dim)
 
 # }}}
@@ -225,8 +236,7 @@ def reference_derivative_matrices(actx, element_group):
 
 
 def _compute_local_gradient(dcoll, vec, xyz_axis):
-    from grudge.geometry import \
-        inverse_surface_metric_derivative
+    from grudge.geometry import inverse_surface_metric_derivative
 
     discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME)
     actx = vec.array_context
@@ -258,11 +268,8 @@ def local_grad(dcoll, vec):
     :arg vec: a :class:`~meshmode.dof_array.DOFArray`
     :returns: an object array of :class:`~meshmode.dof_array.DOFArray`\ s
     """
-    from pytools.obj_array import make_obj_array
-    return make_obj_array(
-        [_compute_local_gradient(dcoll, vec, xyz_axis)
-         for xyz_axis in range(dcoll.dim)]
-    )
+    return make_obj_array([_compute_local_gradient(dcoll, vec, xyz_axis)
+                           for xyz_axis in range(dcoll.dim)])
 
 
 def local_d_dx(dcoll, xyz_axis, vec):
@@ -284,18 +291,16 @@ def _div_helper(dcoll, diff_func, vecs):
     assert vecs.dtype == object
 
     if vecs.shape[-1] != dcoll.ambient_dim:
-        raise ValueError(
-            "last dimension of *vecs* argument must match "
-            "ambient dimension"
-        )
+        raise ValueError("last dimension of *vecs* argument must match "
+                "ambient dimension")
 
     if len(vecs.shape) == 1:
         return sum(diff_func(i, vec_i) for i, vec_i in enumerate(vecs))
     else:
         result = np.zeros(vecs.shape[:-1], dtype=object)
         for idx in np.ndindex(vecs.shape[:-1]):
-            result[idx] = \
-                sum(diff_func(i, vec_i) for i, vec_i in enumerate(vecs[idx]))
+            result[idx] = sum(
+                    diff_func(i, vec_i) for i, vec_i in enumerate(vecs[idx]))
         return result
 
 
@@ -311,11 +316,9 @@ def local_div(dcoll, vecs):
     :returns: a :class:`~meshmode.dof_array.DOFArray`
     """
 
-    return _div_helper(
-        dcoll,
-        lambda i, subvec: local_d_dx(dcoll, i, subvec),
-        vecs
-    )
+    return _div_helper(dcoll,
+            lambda i, subvec: local_d_dx(dcoll, i, subvec),
+            vecs)
 
 # }}}
 
@@ -424,8 +427,6 @@ def weak_local_grad(dcoll, *args):
     else:
         raise TypeError("invalid number of arguments")
 
-    from pytools.obj_array import make_obj_array
-
     return make_obj_array(
         [_apply_stiffness_transpose_operator(dcoll,
                                              dof_desc.DD_VOLUME,
@@ -454,9 +455,9 @@ def weak_local_d_dx(dcoll, *args):
     else:
         raise TypeError("invalid number of arguments")
 
-    return _apply_stiffness_transpose_operator(
-        dcoll, dof_desc.DD_VOLUME, dd, vec, xyz_axis
-    )
+    return _apply_stiffness_transpose_operator(dcoll,
+                                               dof_desc.DD_VOLUME,
+                                               dd, vec, xyz_axis)
 
 
 def weak_local_div(dcoll, *args):
@@ -482,11 +483,9 @@ def weak_local_div(dcoll, *args):
     else:
         raise TypeError("invalid number of arguments")
 
-    return _div_helper(
-        dcoll,
-        lambda i, subvec: weak_local_d_dx(dcoll, dd, i, subvec),
-        vecs
-    )
+    return _div_helper(dcoll,
+            lambda i, subvec: weak_local_d_dx(dcoll, dd, i, subvec),
+            vecs)
 
 # }}}
 
@@ -531,6 +530,13 @@ def reference_mass_matrix(actx, out_element_group, in_element_group):
 
 
 def _apply_mass_operator(dcoll, dd_out, dd_in, vec):
+    if isinstance(vec, np.ndarray):
+        return obj_array_vectorize(
+            lambda vi: _apply_mass_operator(dcoll,
+                                            dd_out,
+                                            dd_in, vi), vec
+        )
+
     from grudge.geometry import area_element
 
     in_discr = dcoll.discr_from_dd(dd_in)
@@ -567,11 +573,6 @@ def mass(dcoll, *args):
     else:
         raise TypeError("invalid number of arguments")
 
-    if isinstance(vec, np.ndarray):
-        return obj_array_vectorize(
-            lambda el: mass(dcoll, dd, el), vec
-        )
-
     return _apply_mass_operator(dcoll, dof_desc.DD_VOLUME, dd, vec)
 
 # }}}
@@ -600,6 +601,13 @@ def reference_inverse_mass_matrix(actx, element_group):
 
 
 def _apply_inverse_mass_operator(dcoll, dd_out, dd_in, vec):
+    if isinstance(vec, np.ndarray):
+        return obj_array_vectorize(
+            lambda vi: _apply_inverse_mass_operator(dcoll,
+                                                    dd_out,
+                                                    dd_in, vi), vec
+        )
+
     from grudge.geometry import area_element
 
     if dd_out != dd_in:
@@ -653,10 +661,6 @@ def _apply_inverse_mass_operator(dcoll, dd_out, dd_in, vec):
 
 
 def inverse_mass(dcoll, vec):
-    if isinstance(vec, np.ndarray):
-        return obj_array_vectorize(
-            lambda el: inverse_mass(dcoll, el), vec
-        )
 
     return _apply_inverse_mass_operator(
         dcoll, dof_desc.DD_VOLUME, dof_desc.DD_VOLUME, vec
@@ -668,15 +672,13 @@ def inverse_mass(dcoll, vec):
 # {{{ Face mass operator
 
 def reference_face_mass_matrix(actx, face_element_group, vol_element_group, dtype):
-
     @keyed_memoize_in(
         actx, reference_mass_matrix,
         lambda face_grp, vol_grp: (face_grp.discretization_key(),
                                    vol_grp.discretization_key()))
     def get_ref_face_mass_mat(face_grp, vol_grp):
         nfaces = vol_grp.mesh_el_group.nfaces
-        assert (face_grp.nelements
-                == nfaces * vol_grp.nelements)
+        assert face_grp.nelements == nfaces * vol_grp.nelements
 
         matrix = np.empty(
             (vol_grp.nunit_dofs,
@@ -719,8 +721,7 @@ def reference_face_mass_matrix(actx, face_element_group, vol_element_group, dtyp
             # If the group has a nodal basis and is unisolvent,
             # we use the basis on the face to compute the face mass matrix
             if (isinstance(face_grp, ElementGroupWithBasis)
-                    and face_grp.space.space_dim
-                    == face_grp.nunit_dofs):
+                    and face_grp.space.space_dim == face_grp.nunit_dofs):
 
                 face_basis = face_grp.basis_obj()
 
@@ -756,6 +757,11 @@ def reference_face_mass_matrix(actx, face_element_group, vol_element_group, dtyp
 
 
 def _apply_face_mass_operator(dcoll, dd, vec):
+    if isinstance(vec, np.ndarray):
+        return obj_array_vectorize(
+            lambda vi: _apply_face_mass_operator(dcoll, dd, vi), vec
+        )
+
     from grudge.geometry import area_element
 
     volm_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME)
@@ -821,10 +827,6 @@ def face_mass(dcoll, *args):
     else:
         raise TypeError("invalid number of arguments")
 
-    if isinstance(vec, np.ndarray):
-        return obj_array_vectorize(
-                lambda el: face_mass(dcoll, dd, el), vec)
-
     return _apply_face_mass_operator(dcoll, dd, vec)
 
 # }}}
@@ -841,8 +843,8 @@ def norm(dcoll, vec, p, dd=None):
         a :class:`~meshmode.dof_array.DOFArray`\ s,
         where the last axis of the array must have length
         matching the volume dimension.
-    :arg p: an integer denoting the order of the integral norm. For example,
-        `p` can be 2, or `numpy.inf`.
+    :arg p: an integer denoting the order of the integral norm. Currently,
+        only `p` values of 2 or `numpy.inf` are supported.
     :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one.
         Defaults to the base volume discretization if not provided.
     :returns: an integer denoting the norm.
@@ -851,18 +853,31 @@ def norm(dcoll, vec, p, dd=None):
         dd = dof_desc.DD_VOLUME
     dd = dof_desc.as_dofdesc(dd)
 
-    if isinstance(vec, np.ndarray):
-        if p == np.inf:
-            return np.max([norm(dcoll, vec_i, p, dd=dd) for vec_i in vec])
-        else:
-            return np.sum([norm(dcoll, vec_i, p, dd=dd) for vec_i in vec])
-
     if p == 2:
-        return np.sqrt(
-            nodal_summation(vec * _apply_mass_operator(dcoll, dd, dd, vec))
+        norm_squared = nodal_summation(
+            vec * _apply_mass_operator(dcoll, dd, dd, vec)
         )
+
+        if isinstance(norm_squared, np.ndarray):
+            if len(norm_squared.shape) != 1:
+                raise NotImplementedError("Can only take the norm of vectors")
+
+            norm_squared = sum(norm_squared)
+
+        return np.sqrt(norm_squared)
+
     elif p == np.inf:
-        return nodal_maximum(vec.array_context.np.fabs(vec))
+        actx = dcoll._setup_actx
+        result = nodal_maximum(actx.np.fabs(vec))
+
+        if isinstance(result, np.ndarray):
+            if len(result.shape) != 1:
+                raise NotImplementedError("Can only take the norm of vectors")
+
+            result = np.max(result)
+
+        return result
+
     else:
         raise NotImplementedError("Unsupported value of p")
 
@@ -1100,7 +1115,6 @@ def cross_rank_trace_pairs(dcoll, ary, tag=None):
     :class:`~meshmode.dof_array.DOFArray`, or an object
     array of ``DOFArray``\ s of arbitrary shape.
     """
-    from pytools.obj_array import make_obj_array
 
     if isinstance(ary, np.ndarray):
         oshape = ary.shape
-- 
GitLab