From 00e008bf7aea7a19327448f145bd7015cbff5d5a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 May 2023 09:11:13 -0700 Subject: [PATCH] copy axes and tags in zeros_like/full_like --- arraycontext/impl/pytato/fake_numpy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 4dad159..f2e3d2e 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -73,7 +73,8 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros(array.shape, array.dtype) + return self._array_context.zeros( + array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) @@ -83,7 +84,8 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def full_like(self, ary, fill_value): def _full_like(subary): - return pt.full(subary.shape, fill_value, subary.dtype) + return pt.full(subary.shape, fill_value, subary.dtype).copy( + axes=subary.axes, tags=subary.tags) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value) -- GitLab