From e7c35b094a69ffa33a6720724e256e222611bdb0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 14 Mar 2023 13:19:55 -0500 Subject: [PATCH] support generating JAX code with NamedArrays in expr graph --- pytato/target/python/numpy_like.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 59b5214..4d522e5 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -37,7 +37,7 @@ from pytato.array import (Stack, Concatenate, IndexLambda, DataWrapper, AxisPermutation, Einsum, Reshape, Array, DictOfNamedArrays, IndexBase, DataInterface, NormalizedSlice, ShapeComponent, - IndexExpr, ArrayOrScalar) + IndexExpr, ArrayOrScalar, NamedArray) from immutables import Map from pytato.scalar_expr import SCALAR_CLASSES from pytato.utils import are_shape_components_equal @@ -499,6 +499,10 @@ class NumpyCodegenMapper(CachedMapper[ArrayOrNames]): return self._record_line_and_return_lhs(lhs, rhs) + def map_named_array(self, expr: NamedArray) -> str: + # type-ignore-reason: CachedMapper.rec's types are imprecise + return self.rec(expr.expr) # type: ignore[no-any-return] + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> str: lhs = self.vng("_pt_tmp") -- GitLab