diff --git a/pytato/utils.py b/pytato/utils.py index c475e1da1d51808527fe735c8e6b3a944de56971..fe48c136824cab01a678b39d7888c7ab700da4b3 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -32,7 +32,8 @@ from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeCompone DtypeOrScalar, ArrayOrScalar, BasicIndex, AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, - ConvertibleToIndexExpr, IndexExpr, NormalizedSlice) + ConvertibleToIndexExpr, IndexExpr, NormalizedSlice, + _dtype_any) from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, SCALAR_CLASSES, INT_CLASSES, BoolT) from pytools import UniqueNameGenerator @@ -47,6 +48,7 @@ Helper routines .. autofunction:: are_shapes_equal .. autofunction:: get_shape_after_broadcasting .. autofunction:: dim_to_index_lambda_components +.. autofunction:: get_common_dtype_of_ary_or_scalars """ @@ -565,3 +567,21 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra if isinstance(idx, NormalizedSlice)]))) # }}} + + +def get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[ArrayOrScalar] + ) -> _dtype_any: + array_types: List[_dtype_any] = [] + scalar_types: List[_dtype_any] = [] + + for ary_or_scalar in ary_or_scalars: + if isinstance(ary_or_scalar, Array): + array_types.append(ary_or_scalar.dtype) + else: + assert isinstance(ary_or_scalar, SCALAR_CLASSES) + scalar_types.append(np.array(ary_or_scalar).dtype) + + return np.find_common_type(array_types=array_types, + scalar_types=scalar_types) + +# vim: fdm=marker