From 37a61820007880a4efeb63c25a8bab335d22dc10 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 7 Jul 2021 09:54:52 -0500
Subject: [PATCH] Improve error message for arithmetic on frozen array
 container

---
 arraycontext/container/arithmetic.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 7d5daa1..3ade1b3 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -274,6 +274,12 @@ def with_container_arithmetic(
             from numbers import Number
             import numpy as np
             from arraycontext import ArrayContainer
+
+            def _raise_if_actx_none(actx):
+                if actx is None:
+                    raise ValueError("array containers with frozen arrays "
+                        "cannot be operated upon")
+                return actx
             """)
         gen("")
 
@@ -375,8 +381,11 @@ def with_container_arithmetic(
                     gen(f"return cls({zip_init_args})")
 
                 if _bcast_actx_array_type:
-                    bcast_actx_ary_types: Tuple[str, ...] = (
-                        "*arg1.array_context.array_types",)
+                    if __debug__:
+                        bcast_actx_ary_types: Tuple[str, ...] = (
+                            "*_raise_if_actx_none(arg1.array_context).array_types",)
+                    else:
+                        bcast_actx_ary_types = ("*arg1.array_context.array_types",)
                 else:
                     bcast_actx_ary_types = ()
 
@@ -410,7 +419,11 @@ def with_container_arithmetic(
                         })
 
                 if _bcast_actx_array_type:
-                    bcast_actx_ary_types = ("*arg2.array_context.array_types",)
+                    if __debug__:
+                        bcast_actx_ary_types = (
+                            "*_raise_if_actx_none(arg2.array_context).array_types",)
+                    else:
+                        bcast_actx_ary_types = ("*arg2.array_context.array_types",)
                 else:
                     bcast_actx_ary_types = ()
 
-- 
GitLab