diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index a221351556c22fcbd56a2eda819e2c03d4facad7..e6d79704cea80ccdc29896b19a19646c60231c3b 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -141,6 +141,11 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): from meshmode.dof_array import obj_or_dof_array_vectorize_n_args return obj_or_dof_array_vectorize_n_args(pt.minimum, x, y) + def where(self, criterion, then, else_): + import pytato as pt + from meshmode.dof_array import obj_or_dof_array_vectorize_n_args + return obj_or_dof_array_vectorize_n_args(pt.where, criterion, then, else_) + def sum(self, a, dtype=None): import pytato as pt if dtype not in [a.dtype, None]: