From ee563aaebe09a8bb9691c0938bda4db81b5c10ea Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Thu, 28 Apr 2022 12:32:27 -0500
Subject: [PATCH] index_lambda_to_hlo: support NaNs

---
 pytato/raising.py   | 6 ++++++
 test/test_pytato.py | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/pytato/raising.py b/pytato/raising.py
index 9835653..25b5972 100644
--- a/pytato/raising.py
+++ b/pytato/raising.py
@@ -1,3 +1,4 @@
+import numpy as np
 import pymbolic.primitives as p
 
 from enum import Enum, auto, unique
@@ -145,6 +146,11 @@ def _as_array_or_scalar(exprs: Sequence[ScalarExpression],
               and (binding_to_subscript[expr.aggregate.name]
                    == expr)):
             result.append(bindings[expr.aggregate.name])
+        elif isinstance(expr, p.NaN):
+            if expr.data_type:
+                result.append(expr.data_type(float("nan")))
+            else:
+                result.append(np.nan)
         else:
             raise UnknownIndexLambdaExpr()
 
diff --git a/test/test_pytato.py b/test/test_pytato.py
index a674098..4f040f8 100755
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -749,6 +749,12 @@ def test_idx_lambda_to_hlo():
     assert (index_lambda_to_high_level_op(pt.broadcast_to(a, (100, 10, 4)))
             == BroadcastOp(a))
 
+    hlo = index_lambda_to_high_level_op(np.nan * a)
+    assert isinstance(hlo, BinaryOp)
+    assert hlo.binary_op == BinaryOpType.MULT
+    assert np.isnan(hlo.x1)
+    assert hlo.x2 is a
+
 
 def test_deduplicate_data_wrappers():
     from pytato.transform import CachedWalkMapper, deduplicate_data_wrappers
-- 
GitLab