From 547899707670b7bbf516bf6f923661a8e0dd538f Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 27 Sep 2021 10:01:46 -0500
Subject: [PATCH] raise if flattened container does not have homogeneous dtypes

---
 arraycontext/container/traversal.py | 45 +++++++++++++++++++----------
 1 file changed, 29 insertions(+), 16 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 06a0726..08d9ba9 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -510,26 +510,36 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
     individual leaf arrays appear in the final array is dependent on the order
     given by :func:`~arraycontext.serialize_container`.
     """
-    @memoize_in(actx, (flatten, "ravel_prg"))
-    def _ravel_prg(shape: Tuple[int, ...]) -> Any:
-        raise NotImplementedError
+    common_dtype = None
+    result: List[Any] = []
 
     def _flatten(subary: ArrayOrContainerT) -> None:
+        nonlocal common_dtype
+
         try:
             iterable = serialize_container(subary)
         except TypeError:
+            if common_dtype is None:
+                common_dtype = subary.dtype
+
+            if subary.dtype != common_dtype:
+                raise ValueError("arrays in container have different dtypes: "
+                        f"got {subary.dtype}, expected {common_dtype}")
+
             try:
                 flat_subary = actx.np.ravel(subary, order="A")
-            except ValueError:
-                flat_subary = actx.call_loopy(
-                        _ravel_prg(subary.shape), ary=subary)
+            except ValueError as exc:
+                # NOTE: we can't do much if the array context fails to ravel,
+                # since it is the one responsible for the actual memory layout
+                raise NotImplementedError("'flatten' requires advanced reshaping "
+                        "functionality that is not implemented in the array "
+                        f"context '{type(actx).__name__}'") from exc
 
             result.append(flat_subary)
         else:
             for _, isubary in iterable:
                 _flatten(isubary)
 
-    result: List[Any] = []
     _flatten(ary)
 
     return actx.np.concatenate(result)
@@ -547,10 +557,6 @@ def unflatten(
     # NOTE: https://github.com/python/mypy/issues/7057
     offset = 0
 
-    @memoize_in(actx, (unflatten, "reshape_prg"))
-    def _reshape_prg(shape: Tuple[int, ...]) -> Any:
-        raise NotImplementedError
-
     def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
         nonlocal offset
 
@@ -562,14 +568,21 @@ def unflatten(
             if offset > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match")
 
+            if template_subary.dtype != ary.dtype:
+                raise ValueError("'template' dtype does not match 'ary': "
+                        f"got {template_subary.dtype}, expected {ary.dtype}")
+
             flat_subary = ary[offset - template_subary.size:offset]
             try:
                 subary = actx.np.reshape(flat_subary, template_subary.shape)
-            except ValueError:
-                subary = actx.call_loopy(
-                        _reshape_prg(template_subary.shape), ary=flat_subary)
-
-            return actx.np.astype(subary, template_subary.dtype)
+            except ValueError as exc:
+                # NOTE: we can't do much if the array context fails to reshape,
+                # since it is the one responsible for the actual memory layout
+                raise NotImplementedError("'unflatten' requires advanced reshaping "
+                        "functionality that is not implemented in the array "
+                        f"context '{type(actx).__name__}'") from exc
+
+            return subary
         else:
             return deserialize_container(template_subary, [
                 (key, _unflatten(isubary)) for key, isubary in iterable
-- 
GitLab