diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index e3d6e6e75e51f86dd1f1d87f52a98ba0d68e82da..ace684b5b8b3a0e7d36dfc7510239a2c4e47d296 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -53,6 +53,18 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def ns(self): return self._array_context.ns + def __getattr__(self, name): + + pt_funcs = ["abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", + "sqrt", "exp"] + if name in pt_funcs: + import pytato as pt + from functools import partial + return partial(rec_map_array_container, getattr(pt, name)) + + return super().__getattr__(name) + def exp(self, x): import pytato as pt return rec_map_array_container(pt.exp, x)