diff --git a/meshmode/dof_array.py b/meshmode/dof_array.py index 3136484c3290e02b9f87fdabb698d3e006555f6a..12ffe33fb19f83e3a38a10f52d434166ccd93e79 100644 --- a/meshmode/dof_array.py +++ b/meshmode/dof_array.py @@ -21,7 +21,7 @@ THE SOFTWARE. """ import numpy as np -from typing import Optional, TYPE_CHECKING +from typing import Optional, Iterable, TYPE_CHECKING from functools import partial from pytools import single_valued, memoize_in @@ -53,11 +53,15 @@ class DOFArray(np.ndarray): :class:`~meshmode.discretization.ElementGroupBase`. The arrays contained within a :class:`DOFArray` are expected to be logically two-dimensional, with shape - ``(nelements, nunit_dofs)``, using - :class:`~meshmode.discretization.ElementGroupBase.nelements` - and :class:`~meshmode.discretization.ElementGroupBase.nunit_dofs`. - It is derived from :class:`numpy.ndarray` with dtype object ("an object array"). - The entries in this array are further arrays managed by :attr:`array_context`. + ``(nelements, ndofs_per_element)``, where ``nelements`` is the same as + :attr:`~meshmode.discretization.ElementGroupBase.nelements` + of the associated group. + ``ndofs_per_element`` is typically, but not necessarily, the same as + :attr:`~meshmode.discretization.ElementGroupBase.nunit_dofs` + of the associated group. + This array is derived from :class:`numpy.ndarray` with dtype object ("an + object array"). The entries in this array are further arrays managed by + :attr:`array_context`. One main purpose of this class is to describe the data structure, i.e. when a :class:`DOFArray` occurs inside of further numpy object array, @@ -105,7 +109,7 @@ class DOFArray(np.ndarray): (one per :class:`~meshmode.discretization.ElementGroupBase`). :arg actx: If *None*, the arrays in *res_list* must be - :meth:`~meshmode.array_context.ArrayContext.thaw`\ ed. + :meth:`~meshmode.array_context.ArrayContext.thaw`\ ed. """ if not (actx is None or isinstance(actx, ArrayContext)): raise TypeError("actx must be of type ArrayContext") @@ -190,8 +194,9 @@ def flatten(ary: np.ndarray) -> np.ndarray: @memoize_in(actx, (flatten, "flatten_prg")) def prg(): return make_loopy_program( - "{[iel,idof]: 0<=iel<nelements and 0<=idof<nunit_dofs}", - "result[grp_start + iel*nunit_dofs + idof] = grp_ary[iel, idof]", + "{[iel,idof]: 0<=iel<nelements and 0<=idof<ndofs_per_element}", + """result[grp_start + iel*ndofs_per_element + idof] \ + = grp_ary[iel, idof]""", name="flatten") result = actx.empty(group_starts[-1], dtype=ary.entry_dtype) @@ -202,7 +207,8 @@ def flatten(ary: np.ndarray) -> np.ndarray: return result -def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray: +def unflatten(actx: ArrayContext, discr: "_Discretization", ary, + ndofs_per_element_per_group: Optional[Iterable[int]] = None) -> np.ndarray: r"""Convert a 'flat' array returned by :func:`flatten` back to a :class:`DOFArray`. Vectorizes over object arrays of :class:`DOFArray`\ s. @@ -211,17 +217,26 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray: and ary.dtype.char == "O" and not isinstance(ary, DOFArray)): return obj_array_vectorize( - lambda subary: unflatten(actx, discr, subary), + lambda subary: unflatten( + actx, discr, subary, ndofs_per_element_per_group), ary) @memoize_in(actx, (unflatten, "unflatten_prg")) def prg(): return make_loopy_program( - "{[iel,idof]: 0<=iel<nelements and 0<=idof<nunit_dofs}", - "result[iel, idof] = ary[grp_start + iel*nunit_dofs + idof]", + "{[iel,idof]: 0<=iel<nelements and 0<=idof<ndofs_per_element}", + "result[iel, idof] = ary[grp_start + iel*ndofs_per_element + idof]", name="unflatten") - group_sizes = [grp.ndofs for grp in discr.groups] + if ndofs_per_element_per_group is None: + ndofs_per_element_per_group = [ + grp.nunit_dofs for grp in discr.groups] + + group_sizes = [ + grp.nelements * ndofs_per_element + for grp, ndofs_per_element + in zip(discr.groups, ndofs_per_element_per_group)] + if ary.size != sum(group_sizes): raise ValueError("array has size %d, expected %d" % (ary.size, sum(group_sizes))) @@ -234,9 +249,12 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray: prg(), grp_start=grp_start, ary=ary, nelements=grp.nelements, - nunit_dofs=grp.nunit_dofs, + ndofs_per_element=ndofs_per_element, )["result"]) - for grp_start, grp in zip(group_starts, discr.groups)]) + for grp_start, grp, ndofs_per_element in zip( + group_starts, + discr.groups, + ndofs_per_element_per_group)]) def flat_norm(ary: DOFArray, ord=2):