Skip to content
Snippets Groups Projects
Commit d9b409e0 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge branch 'dofarray-variable-last-dim-sizes' into 'array-context'

DOFArray: Permit variable final dimension sizes

See merge request inducer/meshmode!90
parents 88aa9028 372927fb
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@ THE SOFTWARE.
"""
import numpy as np
from typing import Optional, TYPE_CHECKING
from typing import Optional, Iterable, TYPE_CHECKING
from functools import partial
from pytools import single_valued, memoize_in
......@@ -53,11 +53,15 @@ class DOFArray(np.ndarray):
:class:`~meshmode.discretization.ElementGroupBase`.
The arrays contained within a :class:`DOFArray`
are expected to be logically two-dimensional, with shape
``(nelements, nunit_dofs)``, using
:class:`~meshmode.discretization.ElementGroupBase.nelements`
and :class:`~meshmode.discretization.ElementGroupBase.nunit_dofs`.
It is derived from :class:`numpy.ndarray` with dtype object ("an object array").
The entries in this array are further arrays managed by :attr:`array_context`.
``(nelements, ndofs_per_element)``, where ``nelements`` is the same as
:attr:`~meshmode.discretization.ElementGroupBase.nelements`
of the associated group.
``ndofs_per_element`` is typically, but not necessarily, the same as
:attr:`~meshmode.discretization.ElementGroupBase.nunit_dofs`
of the associated group.
This array is derived from :class:`numpy.ndarray` with dtype object ("an
object array"). The entries in this array are further arrays managed by
:attr:`array_context`.
One main purpose of this class is to describe the data structure,
i.e. when a :class:`DOFArray` occurs inside of further numpy object array,
......@@ -105,7 +109,7 @@ class DOFArray(np.ndarray):
(one per :class:`~meshmode.discretization.ElementGroupBase`).
:arg actx: If *None*, the arrays in *res_list* must be
:meth:`~meshmode.array_context.ArrayContext.thaw`\ ed.
:meth:`~meshmode.array_context.ArrayContext.thaw`\ ed.
"""
if not (actx is None or isinstance(actx, ArrayContext)):
raise TypeError("actx must be of type ArrayContext")
......@@ -190,8 +194,9 @@ def flatten(ary: np.ndarray) -> np.ndarray:
@memoize_in(actx, (flatten, "flatten_prg"))
def prg():
return make_loopy_program(
"{[iel,idof]: 0<=iel<nelements and 0<=idof<nunit_dofs}",
"result[grp_start + iel*nunit_dofs + idof] = grp_ary[iel, idof]",
"{[iel,idof]: 0<=iel<nelements and 0<=idof<ndofs_per_element}",
"""result[grp_start + iel*ndofs_per_element + idof] \
= grp_ary[iel, idof]""",
name="flatten")
result = actx.empty(group_starts[-1], dtype=ary.entry_dtype)
......@@ -202,7 +207,8 @@ def flatten(ary: np.ndarray) -> np.ndarray:
return result
def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray:
def unflatten(actx: ArrayContext, discr: "_Discretization", ary,
ndofs_per_element_per_group: Optional[Iterable[int]] = None) -> np.ndarray:
r"""Convert a 'flat' array returned by :func:`flatten` back to a :class:`DOFArray`.
Vectorizes over object arrays of :class:`DOFArray`\ s.
......@@ -211,17 +217,26 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray:
and ary.dtype.char == "O"
and not isinstance(ary, DOFArray)):
return obj_array_vectorize(
lambda subary: unflatten(actx, discr, subary),
lambda subary: unflatten(
actx, discr, subary, ndofs_per_element_per_group),
ary)
@memoize_in(actx, (unflatten, "unflatten_prg"))
def prg():
return make_loopy_program(
"{[iel,idof]: 0<=iel<nelements and 0<=idof<nunit_dofs}",
"result[iel, idof] = ary[grp_start + iel*nunit_dofs + idof]",
"{[iel,idof]: 0<=iel<nelements and 0<=idof<ndofs_per_element}",
"result[iel, idof] = ary[grp_start + iel*ndofs_per_element + idof]",
name="unflatten")
group_sizes = [grp.ndofs for grp in discr.groups]
if ndofs_per_element_per_group is None:
ndofs_per_element_per_group = [
grp.nunit_dofs for grp in discr.groups]
group_sizes = [
grp.nelements * ndofs_per_element
for grp, ndofs_per_element
in zip(discr.groups, ndofs_per_element_per_group)]
if ary.size != sum(group_sizes):
raise ValueError("array has size %d, expected %d"
% (ary.size, sum(group_sizes)))
......@@ -234,9 +249,12 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray:
prg(),
grp_start=grp_start, ary=ary,
nelements=grp.nelements,
nunit_dofs=grp.nunit_dofs,
ndofs_per_element=ndofs_per_element,
)["result"])
for grp_start, grp in zip(group_starts, discr.groups)])
for grp_start, grp, ndofs_per_element in zip(
group_starts,
discr.groups,
ndofs_per_element_per_group)])
def flat_norm(ary: DOFArray, ord=2):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment