From e1aee5851518ea45d8809c9c5e31b12ef4afd6f4 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 24 Mar 2022 16:48:15 -0500 Subject: [PATCH] Remove np.nan from expression trees Comparison between NaNs always results in False. This would inhibit our pickling capabilities. --- pytato/array.py | 15 ++++++++++++--- pytato/scalar_expr.py | 1 + pytato/utils.py | 8 +++++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 86c3868..d5d7b46 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1996,7 +1996,14 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, dtype = np.dtype(dtype) shape = normalize_shape(shape) - return IndexLambda(dtype.type(fill_value), shape, dtype, {}, + + if np.isnan(fill_value): + from pymbolic.primitives import NaN + fill_value = NaN(dtype.type) + else: + fill_value = dtype.type(fill_value) + + return IndexLambda(fill_value, shape, dtype, {}, tags=_get_default_tags(), axes=_get_default_axes(len(shape))) @@ -2310,7 +2317,8 @@ def maximum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan # https://github.com/python/mypy/issues/3186 - return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore + return where(logical_or(isnan(x1), isnan(x2)), # type: ignore + common_dtype.type(np.NaN), where(greater(x1, x2), x1, x2)) else: return where(greater(x1, x2), x1, x2) @@ -2328,7 +2336,8 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan # https://github.com/python/mypy/issues/3186 - return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore + return where(logical_or(isnan(x1), isnan(x2)), # type: ignore + common_dtype.type(np.NaN), where(less(x1, x2), x1, x2)) else: return where(less(x1, x2), x1, x2) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index d0b215e..e29a265 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -263,6 +263,7 @@ class Reduce(ExpressionBase): def __getinitargs__(self) -> Tuple[ScalarExpression, ReductionOperation, Any]: return (self.inner_expr, self.op, self.bounds) + init_arg_names = ("inner_expr", "op", "bounds") mapper_method = "map_reduce" # }}} diff --git a/pytato/utils.py b/pytato/utils.py index fe48c13..1eba387 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -158,7 +158,13 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, """ if isinstance(arr, SCALAR_CLASSES): - return arr + if np.isnan(arr): + # allowing NaNs to stay in our expression trees could potentially + # lead to spuriously unequal comparisons between expressions + from pymbolic.primitives import NaN + return NaN(np.array(arr).dtype.type) + else: + return arr assert isinstance(arr, Array) bindings[bnd_name] = arr -- GitLab