diff --git a/pytato/array.py b/pytato/array.py index 86c3868f827307a45de037bab03a894114fb514c..d5d7b468120fdcc2e5f18c2e9017330786bf1d9e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1996,7 +1996,14 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, dtype = np.dtype(dtype) shape = normalize_shape(shape) - return IndexLambda(dtype.type(fill_value), shape, dtype, {}, + + if np.isnan(fill_value): + from pymbolic.primitives import NaN + fill_value = NaN(dtype.type) + else: + fill_value = dtype.type(fill_value) + + return IndexLambda(fill_value, shape, dtype, {}, tags=_get_default_tags(), axes=_get_default_axes(len(shape))) @@ -2310,7 +2317,8 @@ def maximum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan # https://github.com/python/mypy/issues/3186 - return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore + return where(logical_or(isnan(x1), isnan(x2)), # type: ignore + common_dtype.type(np.NaN), where(greater(x1, x2), x1, x2)) else: return where(greater(x1, x2), x1, x2) @@ -2328,7 +2336,8 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan # https://github.com/python/mypy/issues/3186 - return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore + return where(logical_or(isnan(x1), isnan(x2)), # type: ignore + common_dtype.type(np.NaN), where(less(x1, x2), x1, x2)) else: return where(less(x1, x2), x1, x2) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index d0b215e713c56cc80a5250f2dd4c12f16185d75b..e29a26547021c48dca8982366af66799736f6a71 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -263,6 +263,7 @@ class Reduce(ExpressionBase): def __getinitargs__(self) -> Tuple[ScalarExpression, ReductionOperation, Any]: return (self.inner_expr, self.op, self.bounds) + init_arg_names = ("inner_expr", "op", "bounds") mapper_method = "map_reduce" # }}} diff --git a/pytato/utils.py b/pytato/utils.py index fe48c136824cab01a678b39d7888c7ab700da4b3..1eba387ff28a955427c80ed4836885860aba9f1f 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -158,7 +158,13 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, """ if isinstance(arr, SCALAR_CLASSES): - return arr + if np.isnan(arr): + # allowing NaNs to stay in our expression trees could potentially + # lead to spuriously unequal comparisons between expressions + from pymbolic.primitives import NaN + return NaN(np.array(arr).dtype.type) + else: + return arr assert isinstance(arr, Array) bindings[bnd_name] = arr