From 926d381579f8005b31e575c2feab86fe13f73a11 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 19 Jul 2021 20:55:21 -0500
Subject: [PATCH] improves scalar detection

---
 pytato/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytato/utils.py b/pytato/utils.py
index c8a5111..c4efab7 100644
--- a/pytato/utils.py
+++ b/pytato/utils.py
@@ -25,7 +25,6 @@ THE SOFTWARE.
 import numpy as np
 import pymbolic.primitives as prim
 
-from numbers import Number
 from typing import Tuple, List, Union, Callable, Any, Sequence, Dict, Optional
 from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent,
                           DtypeOrScalar, ArrayOrScalar)
@@ -142,7 +141,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
                         op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression],  # noqa:E501
                         get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]],  # noqa:E501
                         ) -> ArrayOrScalar:
-    if isinstance(a1, Number) and isinstance(a2, Number):
+    if np.isscalar(a1) and np.isscalar(a2):
         from pytato.scalar_expr import evaluate
         return evaluate(op(a1, a2))  # type: ignore
 
-- 
GitLab