From c8269936186d8ebe60260867cb63b67ccfbd3948 Mon Sep 17 00:00:00 2001
From: Alex Fikl <alexfikl@gmail.com>
Date: Wed, 6 Oct 2021 19:14:57 -0500
Subject: [PATCH] unflatten: better check that template and ary sizes match
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 arraycontext/container/traversal.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index c951073..7fe30e3 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -570,7 +570,8 @@ def unflatten(
             iterable = serialize_container(template_subary)
         except TypeError:
             if (offset + template_subary.size) > ary.size:
-                raise ValueError("'template' and 'ary' sizes do not match")
+                raise ValueError("'template' and 'ary' sizes do not match: "
+                    "'template' is too large")
 
             if template_subary.dtype != ary.dtype:
                 raise ValueError("'template' dtype does not match 'ary': "
@@ -612,7 +613,12 @@ def unflatten(
                 "only one dimensional arrays can be unflattened: "
                 f"'ary' has shape {ary.shape}")
 
-    return _unflatten(template)
+    result = _unflatten(template)
+    if offset != ary.size:
+        raise ValueError("'template' and 'ary' sizes do not match: "
+            "'ary' is too large")
+
+    return result
 
 # }}}
 
-- 
GitLab