From 86fe823540cb404da626f7c63597ebc9d4375585 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 27 Sep 2021 01:32:30 -0500 Subject: [PATCH] arithmetic fixes to account for np.ndarray being a leaf array --- arraycontext/container/arithmetic.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 63f9327..663cdde 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -492,16 +492,17 @@ def with_container_arithmetic( bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -538,16 +539,16 @@ def with_container_arithmetic( def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_init_args}) return NotImplemented cls.__r{dunder_name}__ = {fname}""") -- GitLab