diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 7f08c1e479e4b9a3d3e9530393d6450903973bcb..145ef63503ff105023534bced05aca763f96c23b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -583,7 +583,10 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: # {{{ flatten / unflatten -def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: +def flatten( + ary: ArrayOrContainerT, actx: ArrayContext, *, + leaf_class: Optional[type] = None, + ) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`. @@ -591,11 +594,20 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: ``ravel`` and ``concatenate`` methods implemented. The order in which the individual leaf arrays appear in the final array is dependent on the order given by :func:`~arraycontext.serialize_container`. + + If *leaf_class* is given, then :func:`unflatten` will not be able to recover + the original *ary*. + + :arg leaf_class: an :class:`~arraycontext.ArrayContainer` class on which + the recursion is stopped (subclasses are not considered). If given, only + the entries of this type are flattened and the rest of the tree + structure is left as is. By default, the recursion is stopped when + a non-:class:`~arraycontext.ArrayContainer` is found, which results in + the whole input container *ary* being flattened. """ common_dtype = None - result: List[Any] = [] - def _flatten(subary: ArrayOrContainerT) -> None: + def _flatten(subary: ArrayOrContainerT) -> List[Any]: nonlocal common_dtype try: @@ -624,17 +636,40 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: "This functionality needs to be implemented by the " "array context.") from exc - result.append(flat_subary) + result = [flat_subary] else: + result = [] for _, isubary in iterable: - _flatten(isubary) + result.extend(_flatten(isubary)) + + return result + + def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any: + result = _flatten(subary) + + if len(result) == 1: + return result[0] + else: + return actx.np.concatenate(result) + + def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any: + if type(subary) is leaf_class: + return _flatten_without_leaf_class(subary) - _flatten(ary) + try: + iterable = serialize_container(subary) + except NotAnArrayContainerError: + return subary + else: + return deserialize_container(subary, [ + (key, _flatten_with_leaf_class(isubary)) + for key, isubary in iterable + ]) - if len(result) == 1: - return result[0] + if leaf_class is None: + return _flatten_without_leaf_class(ary) else: - return actx.np.concatenate(result) + return _flatten_with_leaf_class(ary) def unflatten( @@ -647,6 +682,8 @@ def unflatten( The order and sizes of each slice into *ary* are determined by the array container *template*. + :arg ary: a flat one-dimensional array with a size that matches the + number of entries in *template*. :arg strict: if *True* additional :class:`~numpy.dtype` and stride checking is performed on the unflattened array. Otherwise, these checks are skipped. diff --git a/run-pylint.sh b/run-pylint.sh index 6c7386669da645c3349bfb09355b80d0573f7273..c103b82edb447389387d16db7aec22a14524f5e9 100755 --- a/run-pylint.sh +++ b/run-pylint.sh @@ -20,4 +20,4 @@ if [[ -f .pylintrc-local.yml ]]; then PYLINT_RUNNER_ARGS+=" --yaml-rcfile=.pylintrc-local.yml" fi -python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@" +PYTHONWARNINGS=ignore python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ac7ebbf5798c232cc5ce8aa8bb4b382b98112d46..97f853e036f81b724b3e2e4130cf65acaf6a08b0 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -148,6 +148,10 @@ class DOFArray: # Why tuple([...])? https://stackoverflow.com/a/48592299 return (f"{template_instance_name}.array_context, tuple([{arg}])") + @property + def size(self): + return sum(ary.size for ary in self.data) + @property def real(self): return DOFArray(self.array_context, tuple([subary.real for subary in self])) @@ -1064,6 +1068,30 @@ def test_flatten_array_container_failure(actx_factory): # cannot unflatten partially unflatten(ary, flat_ary[:-1], actx) + +def test_flatten_with_leaf_class(actx_factory): + actx = actx_factory() + + from arraycontext import flatten + arys = _get_test_containers(actx, shapes=512) + + flat = flatten(arys[0], actx, leaf_class=DOFArray) + assert isinstance(flat, actx.array_types) + assert flat.shape == (arys[0].size,) + + flat = flatten(arys[1], actx, leaf_class=DOFArray) + assert isinstance(flat, np.ndarray) and flat.dtype == object + assert all(isinstance(entry, actx.array_types) for entry in flat) + assert all(entry.shape == (arys[0].size,) for entry in flat) + + flat = flatten(arys[3], actx, leaf_class=DOFArray) + assert isinstance(flat, MyContainer) + assert isinstance(flat.mass, actx.array_types) + assert flat.mass.shape == (arys[3].mass.size,) + assert isinstance(flat.enthalpy, actx.array_types) + assert flat.enthalpy.shape == (arys[3].enthalpy.size,) + assert all(isinstance(entry, actx.array_types) for entry in flat.momentum) + # }}}