From 926d381579f8005b31e575c2feab86fe13f73a11 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 19 Jul 2021 20:55:21 -0500 Subject: [PATCH] improves scalar detection --- pytato/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index c8a5111..c4efab7 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -25,7 +25,6 @@ THE SOFTWARE. import numpy as np import pymbolic.primitives as prim -from numbers import Number from typing import Tuple, List, Union, Callable, Any, Sequence, Dict, Optional from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, DtypeOrScalar, ArrayOrScalar) @@ -142,7 +141,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: - if isinstance(a1, Number) and isinstance(a2, Number): + if np.isscalar(a1) and np.isscalar(a2): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore -- GitLab