diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index fb1d58d1ea830f6d3a609efd86d5c969587838d5..d8cc62d19bef2200acd718cef1182407152b9d72 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -25,7 +25,6 @@ THE SOFTWARE. from arraycontext import ( ArrayContainer, - map_array_container, with_container_arithmetic, dataclass_array_container ) @@ -37,6 +36,8 @@ from numbers import Number from pytools import memoize_on_first_arg from pytools.obj_array import obj_array_vectorize, make_obj_array +from grudge.discretization import DiscretizationCollection + from meshmode.dof_array import flatten, unflatten from meshmode.mesh import BTAG_PARTITION @@ -87,6 +88,10 @@ class TracePair: .. autoattribute:: ext .. autoattribute:: avg + .. automethod:: __getattr__ + .. automethod:: __getitem__ + .. automethod:: __len__ + .. note:: :class:`TracePair` is currently used both by the symbolic (deprecated) @@ -103,17 +108,26 @@ class TracePair: object.__setattr__(self, "exterior", exterior) def __getattr__(self, name): - return map_array_container( - lambda ary: getattr(ary, name), - self - ) + """Return a :class:`TracePair` associated with the attributes + of the array containers defining :attr:`int` and :attr:`ext`. + """ + return TracePair(self.dd, + interior=getattr(self.interior, name), + exterior=getattr(self.exterior, name)) def __getitem__(self, index): + """Return a :class:`TracePair` associated with the subarrays + of :attr:`int` and :attr:`ext`, denoted by `index`. + """ return TracePair(self.dd, interior=self.interior[index], exterior=self.exterior[index]) def __len__(self): + """Return the total number of arrays associated with the + :attr:`int` and :attr:`ext` restrictions of the :class:`TracePair`. + Note that both must be the same. + """ assert len(self.exterior) == len(self.interior) return len(self.exterior) @@ -146,11 +160,11 @@ class TracePair: # {{{ Boundary trace pairs -def bdry_trace_pair(dcoll, dd, interior, exterior): +def bdry_trace_pair( + dcoll: DiscretizationCollection, dd, interior, exterior) -> TracePair: """Returns a trace pair defined on the exterior boundary. Input arguments - are assumed to already be defined on the boundary. + are assumed to already be defined on the boundary denoted by *dd*. - :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one, which describes the boundary discretization. :arg interior: a :class:`~meshmode.dof_array.DOFArray` that contains data @@ -159,18 +173,18 @@ def bdry_trace_pair(dcoll, dd, interior, exterior): :arg exterior: a :class:`~meshmode.dof_array.DOFArray` that contains data that already lives on the boundary representing the exterior value to be used for the flux. - :returns: a :class:`grudge.trace_pair.TracePair` on the boundary. + :returns: a :class:`TracePair` on the boundary. """ return TracePair(dd, interior=interior, exterior=exterior) -def bv_trace_pair(dcoll, dd, interior, exterior): +def bv_trace_pair( + dcoll: DiscretizationCollection, dd, interior, exterior) -> TracePair: """Returns a trace pair defined on the exterior boundary. The interior argument is assumed to be defined on the volume discretization, and will - therefore be restricted to the boundary prior to creating a - :class:`grudge.trace_pair.TracePair`. + therefore be restricted to the boundary *dd* prior to creating a + :class:`TracePair`. - :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one, which describes the boundary discretization. :arg interior: a :class:`~meshmode.dof_array.DOFArray` that contains data @@ -180,7 +194,7 @@ def bv_trace_pair(dcoll, dd, interior, exterior): :arg exterior: a :class:`~meshmode.dof_array.DOFArray` that contains data that already lives on the boundary representing the exterior value to be used for the flux. - :returns: a :class:`grudge.trace_pair.TracePair` on the boundary. + :returns: a :class:`TracePair` on the boundary. """ from grudge.op import project @@ -192,15 +206,14 @@ def bv_trace_pair(dcoll, dd, interior, exterior): # {{{ Interior trace pairs -def _interior_trace_pair(dcoll, vec): - r"""Return a :class:`grudge.trace_pair.TracePair` for the interior faces of +def _interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair: + r"""Return a :class:`TracePair` for the interior faces of *dcoll* with a discretization tag specified by *discr_tag*. This does not include interior faces on different MPI ranks. - :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of :class:`~meshmode.dof_array.DOFArray`\ s. - :returns: a :class:`grudge.trace_pair.TracePair` object. + :returns: a :class:`TracePair` object. """ from grudge.op import project @@ -217,15 +230,14 @@ def _interior_trace_pair(dcoll, vec): return TracePair("int_faces", interior=i, exterior=e) -def interior_trace_pairs(dcoll, vec): - r"""Return a :class:`list` of :class:`grudge.trace_pair.TracePair` objects +def interior_trace_pairs(dcoll: DiscretizationCollection, vec) -> list: + r"""Return a :class:`list` of :class:`TracePair` objects defined on the interior faces of *dcoll* and any faces connected to a parallel boundary. - :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg vec: a :class:`~meshmode.dof_array.DOFArray` or object array of :class:`~meshmode.dof_array.DOFArray`\ s. - :returns: a :class:`list` of :class:`grudge.trace_pair.TracePair` objects. + :returns: a :class:`list` of :class:`TracePair` objects. """ return ( [_interior_trace_pair(dcoll, vec)] @@ -233,7 +245,7 @@ def interior_trace_pairs(dcoll, vec): ) -def interior_trace_pair(dcoll, vec): +def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair: from warnings import warn warn("`grudge.op.interior_trace_pair` is deprecated and will be dropped " "in version 2022.x. Use `grudge.trace_pair.interior_trace_pairs` " @@ -247,7 +259,7 @@ def interior_trace_pair(dcoll, vec): # {{{ Distributed-memory functionality @memoize_on_first_arg -def connected_ranks(dcoll): +def connected_ranks(dcoll: DiscretizationCollection): from meshmode.distributed import get_connected_partitions return get_connected_partitions(dcoll._volume_discr.mesh) @@ -255,7 +267,8 @@ def connected_ranks(dcoll): class _RankBoundaryCommunication: base_tag = 1273 - def __init__(self, dcoll, remote_rank, vol_field, tag=None): + def __init__(self, dcoll: DiscretizationCollection, + remote_rank, vol_field, tag=None): self.tag = self.base_tag if tag is not None: self.tag += tag @@ -297,7 +310,8 @@ class _RankBoundaryCommunication: exterior=swapped_remote_dof_array) -def _cross_rank_trace_pairs_scalar_field(dcoll, vec, tag=None): +def _cross_rank_trace_pairs_scalar_field( + dcoll: DiscretizationCollection, vec, tag=None) -> list: if isinstance(vec, Number): return [TracePair(BTAG_PARTITION(remote_rank), interior=vec, exterior=vec) for remote_rank in connected_ranks(dcoll)] @@ -307,7 +321,8 @@ def _cross_rank_trace_pairs_scalar_field(dcoll, vec, tag=None): return [rbcomm.finish() for rbcomm in rbcomms] -def cross_rank_trace_pairs(dcoll, ary, tag=None): +def cross_rank_trace_pairs( + dcoll: DiscretizationCollection, ary, tag=None) -> list: r"""Get a :class:`list` of *ary* trace pairs for each partition boundary. For each partition boundary, the field data values in *ary* are @@ -316,16 +331,15 @@ def cross_rank_trace_pairs(dcoll, ary, tag=None): routine is agnostic to the underlying communication). For each face on each partition boundary, a - :class:`grudge.trace_pair.TracePair` is created with the locally, and + :class:`TracePair` is created with the locally, and remotely owned partition boundary face data as the `internal`, and `external` components, respectively. Each of the TracePair components are structured like *ary*. - :arg dcoll: a :class:`grudge.discretization.DiscretizationCollection`. :arg ary: a single :class:`~meshmode.dof_array.DOFArray`, or an object array of :class:`~meshmode.dof_array.DOFArray`\ s of arbitrary shape. - :returns: a :class:`list` of :class:`grudge.trace_pair.TracePair` objects. + :returns: a :class:`list` of :class:`TracePair` objects. """ if isinstance(ary, np.ndarray): oshape = ary.shape