From ee563aaebe09a8bb9691c0938bda4db81b5c10ea Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Thu, 28 Apr 2022 12:32:27 -0500 Subject: [PATCH] index_lambda_to_hlo: support NaNs --- pytato/raising.py | 6 ++++++ test/test_pytato.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/pytato/raising.py b/pytato/raising.py index 9835653..25b5972 100644 --- a/pytato/raising.py +++ b/pytato/raising.py @@ -1,3 +1,4 @@ +import numpy as np import pymbolic.primitives as p from enum import Enum, auto, unique @@ -145,6 +146,11 @@ def _as_array_or_scalar(exprs: Sequence[ScalarExpression], and (binding_to_subscript[expr.aggregate.name] == expr)): result.append(bindings[expr.aggregate.name]) + elif isinstance(expr, p.NaN): + if expr.data_type: + result.append(expr.data_type(float("nan"))) + else: + result.append(np.nan) else: raise UnknownIndexLambdaExpr() diff --git a/test/test_pytato.py b/test/test_pytato.py index a674098..4f040f8 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -749,6 +749,12 @@ def test_idx_lambda_to_hlo(): assert (index_lambda_to_high_level_op(pt.broadcast_to(a, (100, 10, 4))) == BroadcastOp(a)) + hlo = index_lambda_to_high_level_op(np.nan * a) + assert isinstance(hlo, BinaryOp) + assert hlo.binary_op == BinaryOpType.MULT + assert np.isnan(hlo.x1) + assert hlo.x2 is a + def test_deduplicate_data_wrappers(): from pytato.transform import CachedWalkMapper, deduplicate_data_wrappers -- GitLab