diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 59b52149c228f18dd51a5fbe46232c3c4e967116..4d522e50714da5181ee816033378b797e9f2a934 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -37,7 +37,7 @@ from pytato.array import (Stack, Concatenate, IndexLambda, DataWrapper, AxisPermutation, Einsum, Reshape, Array, DictOfNamedArrays, IndexBase, DataInterface, NormalizedSlice, ShapeComponent, - IndexExpr, ArrayOrScalar) + IndexExpr, ArrayOrScalar, NamedArray) from immutables import Map from pytato.scalar_expr import SCALAR_CLASSES from pytato.utils import are_shape_components_equal @@ -499,6 +499,10 @@ class NumpyCodegenMapper(CachedMapper[ArrayOrNames]): return self._record_line_and_return_lhs(lhs, rhs) + def map_named_array(self, expr: NamedArray) -> str: + # type-ignore-reason: CachedMapper.rec's types are imprecise + return self.rec(expr.expr) # type: ignore[no-any-return] + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> str: lhs = self.vng("_pt_tmp")