From 28daaede6e1f3b4e3d7e31610137e92a8c08f866 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Tue, 1 Jun 2021 12:01:37 -0500 Subject: [PATCH] implement ones_like() --- arraycontext/impl/pytato.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 25220e6..e2a76fe 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -88,6 +88,7 @@ class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): else: raise NotImplementedError(f"unsupported value of 'ord': {ord}") + class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def _get_fake_numpy_linalg_namespace(self): return _PytatoFakeNumpyLinalgNamespace(self._array_context) @@ -118,6 +119,13 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): from meshmode.dof_array import obj_or_dof_array_vectorize_n_args return obj_or_dof_array_vectorize_n_args(pt.concatenate, arrays, axis) + def ones_like(self, ary): + def _ones_like(subary): + import pytato as pt + return pt.ones(subary.shape, subary.dtype) + + return self._new_like(ary, _ones_like) + def maximum(self, x, y): import pytato as pt from meshmode.dof_array import obj_or_dof_array_vectorize_n_args -- GitLab