diff --git a/pytato/array.py b/pytato/array.py index 07b1dfee4ffc97be8e0787a5ce96d7434b754760..860007025980e46b3a05e69799a7f207197dd4b0 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1983,15 +1983,19 @@ def make_data_wrapper(data: DataInterface, # {{{ full def full(shape: ConvertibleToShape, fill_value: ScalarType, - dtype: Any, order: str = "C") -> Array: + dtype: Any = None, order: str = "C") -> Array: """ Returns an array of shape *shape* with all entries equal to *fill_value*. """ if order != "C": raise ValueError("Only C-ordered arrays supported for now.") + if dtype is None: + dtype = np.array(fill_value).dtype + else: + dtype = np.dtype(dtype) + shape = normalize_shape(shape) - dtype = np.dtype(dtype) return IndexLambda(dtype.type(fill_value), shape, dtype, {}, tags=_get_default_tags(), axes=_get_default_axes(len(shape))) @@ -2299,10 +2303,17 @@ def maximum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: Returns the elementwise maximum of *x1*, *x2*. *x1*, *x2* being array-like objects that could be broadcasted together. NaNs are propagated. """ - # https://github.com/python/mypy/issues/3186 - from pytato.cmath import isnan - return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore - where(greater(x1, x2), x1, x2)) + from pytato.utils import get_common_dtype_of_ary_or_scalars + common_dtype = get_common_dtype_of_ary_or_scalars([x1, x2]) + + if (np.issubdtype(common_dtype, np.floating) + or np.issubdtype(common_dtype, np.complexfloating)): + from pytato.cmath import isnan + # https://github.com/python/mypy/issues/3186 + return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore + where(greater(x1, x2), x1, x2)) + else: + return where(greater(x1, x2), x1, x2) def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: @@ -2310,10 +2321,17 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: Returns the elementwise minimum of *x1*, *x2*. *x1*, *x2* being array-like objects that could be broadcasted together. NaNs are propagated. """ - # https://github.com/python/mypy/issues/3186 - from pytato.cmath import isnan - return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore - where(less(x1, x2), x1, x2)) + from pytato.utils import get_common_dtype_of_ary_or_scalars + common_dtype = get_common_dtype_of_ary_or_scalars([x1, x2]) + + if (np.issubdtype(common_dtype, np.floating) + or np.issubdtype(common_dtype, np.complexfloating)): + from pytato.cmath import isnan + # https://github.com/python/mypy/issues/3186 + return where(logical_or(isnan(x1), isnan(x2)), np.NaN, # type: ignore + where(less(x1, x2), x1, x2)) + else: + return where(less(x1, x2), x1, x2) # }}}