From 86fe823540cb404da626f7c63597ebc9d4375585 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 27 Sep 2021 01:32:30 -0500
Subject: [PATCH] arithmetic fixes to account for np.ndarray being a leaf array

---
 arraycontext/container/arithmetic.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 63f9327..663cdde 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -492,16 +492,17 @@ def with_container_arithmetic(
                     bcast_actx_ary_types = ()
 
                 gen(f"""
-                if {bool(outer_bcast_type_names)}:  # optimized away
-                    if isinstance(arg2,
-                                  {tup_str(outer_bcast_type_names
-                                           + bcast_actx_ary_types)}):
-                        return cls({bcast_same_cls_init_args})
                 if {numpy_pred("arg2")}:
                     result = np.empty_like(arg2, dtype=object)
                     for i in np.ndindex(arg2.shape):
                         result[i] = {op_str.format("arg1", "arg2[i]")}
                     return result
+
+                if {bool(outer_bcast_type_names)}:  # optimized away
+                    if isinstance(arg2,
+                                  {tup_str(outer_bcast_type_names
+                                           + bcast_actx_ary_types)}):
+                        return cls({bcast_same_cls_init_args})
                 return NotImplemented
                 """)
             gen(f"cls.__{dunder_name}__ = {fname}")
@@ -538,16 +539,16 @@ def with_container_arithmetic(
                     def {fname}(arg2, arg1):
                         # assert other.__cls__ is not cls
 
-                        if {bool(outer_bcast_type_names)}:  # optimized away
-                            if isinstance(arg1,
-                                          {tup_str(outer_bcast_type_names
-                                                   + bcast_actx_ary_types)}):
-                                return cls({bcast_init_args})
                         if {numpy_pred("arg1")}:
                             result = np.empty_like(arg1, dtype=object)
                             for i in np.ndindex(arg1.shape):
                                 result[i] = {op_str.format("arg1[i]", "arg2")}
                             return result
+                        if {bool(outer_bcast_type_names)}:  # optimized away
+                            if isinstance(arg1,
+                                          {tup_str(outer_bcast_type_names
+                                                   + bcast_actx_ary_types)}):
+                                return cls({bcast_init_args})
                         return NotImplemented
 
                     cls.__r{dunder_name}__ = {fname}""")
-- 
GitLab