From e8c7045f7fbc06eaf8548abac1610a16e8ecf574 Mon Sep 17 00:00:00 2001 From: xywei Date: Tue, 10 Dec 2019 10:09:32 -0600 Subject: [PATCH] Add preconditions of the broadcast to the ufunc results --- lappy/core/array.py | 3 ++- lappy/core/broadcast.py | 8 ++++---- lappy/core/ufuncs.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/lappy/core/array.py b/lappy/core/array.py index 8a20567..3c9e957 100644 --- a/lappy/core/array.py +++ b/lappy/core/array.py @@ -644,7 +644,8 @@ class Array(LazyObjectBase): raise ValueError( "cannot understand the return value " "of the precondition checker %s" % str(checker)) - except Exception: # noqa: W0703 + except Exception as e: # noqa: W0703 + print(e) failed_checks.append(checker) err_msgs = [] diff --git a/lappy/core/broadcast.py b/lappy/core/broadcast.py index ae0012f..11e5fe4 100644 --- a/lappy/core/broadcast.py +++ b/lappy/core/broadcast.py @@ -93,7 +93,7 @@ class BroadcastResult(object): else: # same name must have the same runtime values # (may have different expressions) - def check_name_valule_consistency(s, context): + def check_name_valule_consistency(context, s): s_val = evaluate(s.expr, context) another_s_val = evaluate( expr_map[var(s.name)], context) @@ -128,7 +128,7 @@ class BroadcastResult(object): # allow symbol == 1 or symbol == constant at runtime assert isinstance(si, str) - def check_broadcast_symbol_val(iaxis, si, context): + def check_broadcast_symbol_val(context, iaxis, si): si_val = evaluate(var(si), context) return si_val in (bshape_pre[iaxis], 1) @@ -150,7 +150,7 @@ class BroadcastResult(object): # allow symbol == 1 or symbol == constant at runtime assert isinstance(bshape_pre[iaxis], Expression) - def check_broadcast_symbol_val(iaxis, context): + def check_broadcast_symbol_val(context, iaxis): lhs_val = evaluate(bshape_pre[iaxis], context) return lhs_val in (bshape_pre[iaxis], 1) @@ -167,7 +167,7 @@ class BroadcastResult(object): assert isinstance(bshape_pre[iaxis], Expression) assert isinstance(si, str) - def check_broadcast_symbol_val(iaxis, si, context): + def check_broadcast_symbol_val(context, iaxis, si): lhs_val = evaluate(bshape_pre[iaxis], context) si_val = evaluate(var(si), context) return (lhs_val == si_val) or ( diff --git a/lappy/core/ufuncs.py b/lappy/core/ufuncs.py index 7fab3c2..1c8d322 100644 --- a/lappy/core/ufuncs.py +++ b/lappy/core/ufuncs.py @@ -298,7 +298,7 @@ class BinaryOperation(UFunc): 'bound_arguments': new_bound_arglist, 'intermediaries': new_interm, 'env': new_env, - 'preconditions': a.preconditions + b.preconditions, + 'preconditions': a.preconditions + b.preconditions + bres.preconditions, 'ndim': bres.ndim, 'shape': bres._shape_exprs, 'dtype': new_dtype, 'is_integral': all([ -- GitLab