diff --git a/pytato/utils.py b/pytato/utils.py index c8a5111f3e71e5d0f035de0fcef1a608929466ef..c4efab7b910cd516adfd3d8a94d1445ee38b50d3 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -25,7 +25,6 @@ THE SOFTWARE. import numpy as np import pymbolic.primitives as prim -from numbers import Number from typing import Tuple, List, Union, Callable, Any, Sequence, Dict, Optional from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, DtypeOrScalar, ArrayOrScalar) @@ -142,7 +141,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: - if isinstance(a1, Number) and isinstance(a2, Number): + if np.isscalar(a1) and np.isscalar(a2): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore