Skip to content
Snippets Groups Projects
Commit 926d3815 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

improves scalar detection

parent b8dd7aef
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,6 @@ THE SOFTWARE.
import numpy as np
import pymbolic.primitives as prim
from numbers import Number
from typing import Tuple, List, Union, Callable, Any, Sequence, Dict, Optional
from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent,
DtypeOrScalar, ArrayOrScalar)
......@@ -142,7 +141,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501
get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501
) -> ArrayOrScalar:
if isinstance(a1, Number) and isinstance(a2, Number):
if np.isscalar(a1) and np.isscalar(a2):
from pytato.scalar_expr import evaluate
return evaluate(op(a1, a2)) # type: ignore
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment