diff --git a/meshmode/dof_array.py b/meshmode/dof_array.py index 3136484c3290e02b9f87fdabb698d3e006555f6a..12ffe33fb19f83e3a38a10f52d434166ccd93e79 100644 --- a/meshmode/dof_array.py +++ b/meshmode/dof_array.py @@ -21,7 +21,7 @@ THE SOFTWARE. """ import numpy as np -from typing import Optional, TYPE_CHECKING +from typing import Optional, Iterable, TYPE_CHECKING from functools import partial from pytools import single_valued, memoize_in @@ -53,11 +53,15 @@ class DOFArray(np.ndarray): :class:`~meshmode.discretization.ElementGroupBase`. The arrays contained within a :class:`DOFArray` are expected to be logically two-dimensional, with shape - ``(nelements, nunit_dofs)``, using - :class:`~meshmode.discretization.ElementGroupBase.nelements` - and :class:`~meshmode.discretization.ElementGroupBase.nunit_dofs`. - It is derived from :class:`numpy.ndarray` with dtype object ("an object array"). - The entries in this array are further arrays managed by :attr:`array_context`. + ``(nelements, ndofs_per_element)``, where ``nelements`` is the same as + :attr:`~meshmode.discretization.ElementGroupBase.nelements` + of the associated group. + ``ndofs_per_element`` is typically, but not necessarily, the same as + :attr:`~meshmode.discretization.ElementGroupBase.nunit_dofs` + of the associated group. + This array is derived from :class:`numpy.ndarray` with dtype object ("an + object array"). The entries in this array are further arrays managed by + :attr:`array_context`. One main purpose of this class is to describe the data structure, i.e. when a :class:`DOFArray` occurs inside of further numpy object array, @@ -105,7 +109,7 @@ class DOFArray(np.ndarray): (one per :class:`~meshmode.discretization.ElementGroupBase`). :arg actx: If *None*, the arrays in *res_list* must be - :meth:`~meshmode.array_context.ArrayContext.thaw`\ ed. + :meth:`~meshmode.array_context.ArrayContext.thaw`\ ed. """ if not (actx is None or isinstance(actx, ArrayContext)): raise TypeError("actx must be of type ArrayContext") @@ -190,8 +194,9 @@ def flatten(ary: np.ndarray) -> np.ndarray: @memoize_in(actx, (flatten, "flatten_prg")) def prg(): return make_loopy_program( - "{[iel,idof]: 0<=iel np.ndarray: return result -def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray: +def unflatten(actx: ArrayContext, discr: "_Discretization", ary, + ndofs_per_element_per_group: Optional[Iterable[int]] = None) -> np.ndarray: r"""Convert a 'flat' array returned by :func:`flatten` back to a :class:`DOFArray`. Vectorizes over object arrays of :class:`DOFArray`\ s. @@ -211,17 +217,26 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray: and ary.dtype.char == "O" and not isinstance(ary, DOFArray)): return obj_array_vectorize( - lambda subary: unflatten(actx, discr, subary), + lambda subary: unflatten( + actx, discr, subary, ndofs_per_element_per_group), ary) @memoize_in(actx, (unflatten, "unflatten_prg")) def prg(): return make_loopy_program( - "{[iel,idof]: 0<=iel np.ndarray: prg(), grp_start=grp_start, ary=ary, nelements=grp.nelements, - nunit_dofs=grp.nunit_dofs, + ndofs_per_element=ndofs_per_element, )["result"]) - for grp_start, grp in zip(group_starts, discr.groups)]) + for grp_start, grp, ndofs_per_element in zip( + group_starts, + discr.groups, + ndofs_per_element_per_group)]) def flat_norm(ary: DOFArray, ord=2):