From 6b8f1344d50800aef97bae987c9a882d1dd7c3b8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 16 Mar 2022 19:19:08 -0500 Subject: [PATCH] Add flat_size_and_dtype Tweak docstring Co-authored-by: Alex Fikl <alexfikl@gmail.com> --- arraycontext/__init__.py | 4 ++-- arraycontext/container/traversal.py | 31 +++++++++++++++++++++++++++++ arraycontext/context.py | 2 +- test/test_arraycontext.py | 23 +++++++++++++++------ 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index de51ee2..8206fb8 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -59,7 +59,7 @@ from .container.traversal import ( rec_map_reduce_array_container, rec_multimap_reduce_array_container, thaw, freeze, - flatten, unflatten, + flatten, unflatten, flat_size_and_dtype, from_numpy, to_numpy, outer) @@ -97,7 +97,7 @@ __all__ = ( "map_reduce_array_container", "multimap_reduce_array_container", "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", - "flatten", "unflatten", + "flatten", "unflatten", "flat_size_and_dtype", "from_numpy", "to_numpy", "outer", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 145ef63..3a11b81 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -27,6 +27,7 @@ Flattening and unflattening ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: flatten .. autofunction:: unflatten +.. autofunction:: flat_size_and_dtype Numpy conversion ~~~~~~~~~~~~~~~~ @@ -771,6 +772,36 @@ def unflatten( return result + +def flat_size_and_dtype( + ary: ArrayOrContainerT) -> Tuple[int, Optional[np.dtype[Any]]]: + """ + :returns: a tuple ``(size, dtype)`` that would be the length and + :class:`numpy.dtype` of the one-dimensional array returned by + :func:`flatten`. + """ + common_dtype = None + + def _flat_size(subary: ArrayOrContainerT) -> int: + nonlocal common_dtype + + try: + iterable = serialize_container(subary) + except NotAnArrayContainerError: + if common_dtype is None: + common_dtype = subary.dtype + + if subary.dtype != common_dtype: + raise ValueError("arrays in container have different dtypes: " + f"got {subary.dtype}, expected {common_dtype}") + + return subary.size + else: + return sum(_flat_size(isubary) for _, isubary in iterable) + + size = _flat_size(ary) + return size, common_dtype + # }}} diff --git a/arraycontext/context.py b/arraycontext/context.py index a127b16..6c13a33 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -334,7 +334,7 @@ class ArrayContext(ABC): return self.tag(tagged, out_ary) @abstractmethod - def clone(self): + def clone(self) -> "ArrayContext": """If possible, return a version of *self* that is semantically equivalent (i.e. implements all array operations in the same way) but is a separate object. May return *self* if that is not possible. diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 54acefc..649f5fb 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1050,12 +1050,24 @@ def test_flatten_array_container(actx_factory, shapes): # }}} +def _checked_flatten(ary, actx, leaf_class=None): + from arraycontext import flatten, flat_size_and_dtype + result = flatten(ary, actx, leaf_class=leaf_class) + + if leaf_class is None: + size, dtype = flat_size_and_dtype(ary) + assert result.shape == (size,) + assert result.dtype == dtype + + return result + + def test_flatten_array_container_failure(actx_factory): actx = actx_factory() - from arraycontext import flatten, unflatten + from arraycontext import unflatten ary = _get_test_containers(actx, shapes=512)[0] - flat_ary = flatten(ary, actx) + flat_ary = _checked_flatten(ary, actx) with pytest.raises(TypeError): # cannot unflatten from a numpy array @@ -1073,19 +1085,18 @@ def test_flatten_array_container_failure(actx_factory): def test_flatten_with_leaf_class(actx_factory): actx = actx_factory() - from arraycontext import flatten arys = _get_test_containers(actx, shapes=512) - flat = flatten(arys[0], actx, leaf_class=DOFArray) + flat = _checked_flatten(arys[0], actx, leaf_class=DOFArray) assert isinstance(flat, actx.array_types) assert flat.shape == (arys[0].size,) - flat = flatten(arys[1], actx, leaf_class=DOFArray) + flat = _checked_flatten(arys[1], actx, leaf_class=DOFArray) assert isinstance(flat, np.ndarray) and flat.dtype == object assert all(isinstance(entry, actx.array_types) for entry in flat) assert all(entry.shape == (arys[0].size,) for entry in flat) - flat = flatten(arys[3], actx, leaf_class=DOFArray) + flat = _checked_flatten(arys[3], actx, leaf_class=DOFArray) assert isinstance(flat, MyContainer) assert isinstance(flat.mass, actx.array_types) assert flat.mass.shape == (arys[3].mass.size,) -- GitLab