diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 25220e663c4614fbdf094d40e723bd5425f38102..e2a76feb1c1fb527829afb658418114236190eac 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -88,6 +88,7 @@ class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): else: raise NotImplementedError(f"unsupported value of 'ord': {ord}") + class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def _get_fake_numpy_linalg_namespace(self): return _PytatoFakeNumpyLinalgNamespace(self._array_context) @@ -118,6 +119,13 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): from meshmode.dof_array import obj_or_dof_array_vectorize_n_args return obj_or_dof_array_vectorize_n_args(pt.concatenate, arrays, axis) + def ones_like(self, ary): + def _ones_like(subary): + import pytato as pt + return pt.ones(subary.shape, subary.dtype) + + return self._new_like(ary, _ones_like) + def maximum(self, x, y): import pytato as pt from meshmode.dof_array import obj_or_dof_array_vectorize_n_args