From 6b8f1344d50800aef97bae987c9a882d1dd7c3b8 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 16 Mar 2022 19:19:08 -0500
Subject: [PATCH] Add flat_size_and_dtype

Tweak docstring

Co-authored-by: Alex Fikl <alexfikl@gmail.com>
---
 arraycontext/__init__.py            |  4 ++--
 arraycontext/container/traversal.py | 31 +++++++++++++++++++++++++++++
 arraycontext/context.py             |  2 +-
 test/test_arraycontext.py           | 23 +++++++++++++++------
 4 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index de51ee2..8206fb8 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -59,7 +59,7 @@ from .container.traversal import (
         rec_map_reduce_array_container,
         rec_multimap_reduce_array_container,
         thaw, freeze,
-        flatten, unflatten,
+        flatten, unflatten, flat_size_and_dtype,
         from_numpy, to_numpy,
         outer)
 
@@ -97,7 +97,7 @@ __all__ = (
         "map_reduce_array_container", "multimap_reduce_array_container",
         "rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
         "thaw", "freeze",
-        "flatten", "unflatten",
+        "flatten", "unflatten", "flat_size_and_dtype",
         "from_numpy", "to_numpy",
         "outer",
 
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 145ef63..3a11b81 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -27,6 +27,7 @@ Flattening and unflattening
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: flatten
 .. autofunction:: unflatten
+.. autofunction:: flat_size_and_dtype
 
 Numpy conversion
 ~~~~~~~~~~~~~~~~
@@ -771,6 +772,36 @@ def unflatten(
 
     return result
 
+
+def flat_size_and_dtype(
+        ary: ArrayOrContainerT) -> Tuple[int, Optional[np.dtype[Any]]]:
+    """
+    :returns: a tuple ``(size, dtype)`` that would be the length and
+        :class:`numpy.dtype` of the one-dimensional array returned by
+        :func:`flatten`.
+    """
+    common_dtype = None
+
+    def _flat_size(subary: ArrayOrContainerT) -> int:
+        nonlocal common_dtype
+
+        try:
+            iterable = serialize_container(subary)
+        except NotAnArrayContainerError:
+            if common_dtype is None:
+                common_dtype = subary.dtype
+
+            if subary.dtype != common_dtype:
+                raise ValueError("arrays in container have different dtypes: "
+                        f"got {subary.dtype}, expected {common_dtype}")
+
+            return subary.size
+        else:
+            return sum(_flat_size(isubary) for _, isubary in iterable)
+
+    size = _flat_size(ary)
+    return size, common_dtype
+
 # }}}
 
 
diff --git a/arraycontext/context.py b/arraycontext/context.py
index a127b16..6c13a33 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -334,7 +334,7 @@ class ArrayContext(ABC):
         return self.tag(tagged, out_ary)
 
     @abstractmethod
-    def clone(self):
+    def clone(self) -> "ArrayContext":
         """If possible, return a version of *self* that is semantically
         equivalent (i.e. implements all array operations in the same way)
         but is a separate object. May return *self* if that is not possible.
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 54acefc..649f5fb 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -1050,12 +1050,24 @@ def test_flatten_array_container(actx_factory, shapes):
     # }}}
 
 
+def _checked_flatten(ary, actx, leaf_class=None):
+    from arraycontext import flatten, flat_size_and_dtype
+    result = flatten(ary, actx, leaf_class=leaf_class)
+
+    if leaf_class is None:
+        size, dtype = flat_size_and_dtype(ary)
+        assert result.shape == (size,)
+        assert result.dtype == dtype
+
+    return result
+
+
 def test_flatten_array_container_failure(actx_factory):
     actx = actx_factory()
 
-    from arraycontext import flatten, unflatten
+    from arraycontext import unflatten
     ary = _get_test_containers(actx, shapes=512)[0]
-    flat_ary = flatten(ary, actx)
+    flat_ary = _checked_flatten(ary, actx)
 
     with pytest.raises(TypeError):
         # cannot unflatten from a numpy array
@@ -1073,19 +1085,18 @@ def test_flatten_array_container_failure(actx_factory):
 def test_flatten_with_leaf_class(actx_factory):
     actx = actx_factory()
 
-    from arraycontext import flatten
     arys = _get_test_containers(actx, shapes=512)
 
-    flat = flatten(arys[0], actx, leaf_class=DOFArray)
+    flat = _checked_flatten(arys[0], actx, leaf_class=DOFArray)
     assert isinstance(flat, actx.array_types)
     assert flat.shape == (arys[0].size,)
 
-    flat = flatten(arys[1], actx, leaf_class=DOFArray)
+    flat = _checked_flatten(arys[1], actx, leaf_class=DOFArray)
     assert isinstance(flat, np.ndarray) and flat.dtype == object
     assert all(isinstance(entry, actx.array_types) for entry in flat)
     assert all(entry.shape == (arys[0].size,) for entry in flat)
 
-    flat = flatten(arys[3], actx, leaf_class=DOFArray)
+    flat = _checked_flatten(arys[3], actx, leaf_class=DOFArray)
     assert isinstance(flat, MyContainer)
     assert isinstance(flat.mass, actx.array_types)
     assert flat.mass.shape == (arys[3].mass.size,)
-- 
GitLab