diff --git a/pytato/array.py b/pytato/array.py index 964ff351f29a07cbe16f93385725d913a67f9275..9fdf1de028ca90342145c646176482c58f5c39f5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1058,8 +1058,7 @@ class Einsum(Array): @cached_property def dtype(self) -> np.dtype[Any]: - return np.find_common_type(array_types=[arg.dtype for arg in self.args], - scalar_types=[]) + return np.result_type(*[arg.dtype for arg in self.args]) def with_tagged_reduction(self, redn_axis: Union[EinsumReductionAxis, str], @@ -2427,7 +2426,7 @@ def where(condition: ArrayOrScalar, x_dtype = x.dtype if isinstance(x, Array) else np.dtype(type(x)) y_dtype = y.dtype if isinstance(y, Array) else np.dtype(type(y)) - dtype = np.find_common_type([x_dtype, y_dtype], []) + dtype = np.promote_types(x_dtype, y_dtype) # }}} @@ -2467,7 +2466,8 @@ def maximum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan return where(logical_or(isnan(x1), isnan(x2)), - common_dtype.type(np.NaN), + # I don't know why pylint thinks common_dtype is a tuple. + common_dtype.type(np.NaN), # pylint: disable=no-member where(greater(x1, x2), x1, x2)) else: return where(greater(x1, x2), x1, x2) @@ -2485,7 +2485,8 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan return where(logical_or(isnan(x1), isnan(x2)), - common_dtype.type(np.NaN), + # I don't know why pylint thinks common_dtype is a tuple. + common_dtype.type(np.NaN), # pylint: disable=no-member where(less(x1, x2), x1, x2)) else: return where(less(x1, x2), x1, x2) diff --git a/pytato/utils.py b/pytato/utils.py index 6b5811b7edf94da4e0a6408385d6eb67cace70a4..920d537be48816c1a09e98ab69274857a8883ed2 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -35,7 +35,7 @@ from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeCompone ConvertibleToIndexExpr, IndexExpr, NormalizedSlice, _dtype_any, Einsum) from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, - SCALAR_CLASSES, INT_CLASSES, BoolT) + SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType) from pytools import UniqueNameGenerator from pytato.transform import Mapper from immutables import Map @@ -585,17 +585,16 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra def get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[ArrayOrScalar] ) -> _dtype_any: array_types: List[_dtype_any] = [] - scalar_types: List[_dtype_any] = [] + scalars: List[ScalarType] = [] for ary_or_scalar in ary_or_scalars: if isinstance(ary_or_scalar, Array): array_types.append(ary_or_scalar.dtype) else: assert isinstance(ary_or_scalar, SCALAR_CLASSES) - scalar_types.append(np.array(ary_or_scalar).dtype) + scalars.append(ary_or_scalar) - return np.find_common_type(array_types=array_types, - scalar_types=scalar_types) + return np.result_type(*array_types, *scalars) def get_einsum_subscript_str(expr: Einsum) -> str: