diff --git a/pytato/raising.py b/pytato/raising.py index 98356532c2df1d984b37f11fa7ff55976e789549..25b5972a0705b900e810a795321f3ccf3d889e47 100644 --- a/pytato/raising.py +++ b/pytato/raising.py @@ -1,3 +1,4 @@ +import numpy as np import pymbolic.primitives as p from enum import Enum, auto, unique @@ -145,6 +146,11 @@ def _as_array_or_scalar(exprs: Sequence[ScalarExpression], and (binding_to_subscript[expr.aggregate.name] == expr)): result.append(bindings[expr.aggregate.name]) + elif isinstance(expr, p.NaN): + if expr.data_type: + result.append(expr.data_type(float("nan"))) + else: + result.append(np.nan) else: raise UnknownIndexLambdaExpr() diff --git a/test/test_pytato.py b/test/test_pytato.py index a67409890dccdd8aa294dddd628e289da809a171..4f040f821a74cd4d9f2c19cbe84f947dddc51f0a 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -749,6 +749,12 @@ def test_idx_lambda_to_hlo(): assert (index_lambda_to_high_level_op(pt.broadcast_to(a, (100, 10, 4))) == BroadcastOp(a)) + hlo = index_lambda_to_high_level_op(np.nan * a) + assert isinstance(hlo, BinaryOp) + assert hlo.binary_op == BinaryOpType.MULT + assert np.isnan(hlo.x1) + assert hlo.x2 is a + def test_deduplicate_data_wrappers(): from pytato.transform import CachedWalkMapper, deduplicate_data_wrappers