diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 5291c9962fefa15480c13a7d4c7792bdec55b9c0..f459f229458a66b8e8e28caf5864ce4c229abc3c 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -30,6 +30,8 @@ THE SOFTWARE. from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace from arraycontext.context import ArrayContext +from arraycontext.container.traversal import (rec_multimap_array_container,) + rec_map_array_container) import numpy as np from typing import Any, Callable, Tuple, Union, Sequence from pytools.tag import Tag @@ -53,30 +55,23 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def exp(self, x): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize - return obj_or_dof_array_vectorize(pt.exp, x) + return rec_map_array_container(pt.exp, x) def sin(self, x): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize - return obj_or_dof_array_vectorize(pt.sin, x) + return rec_map_array_container(pt.sin, x) def reshape(self, a, newshape): import pytato as pt - - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.reshape, a, newshape) + return rec_multimap_array_container(pt.reshape, a, newshape) def transpose(self, a, axes=None): import pytato as pt - - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.transpose, a, axes) + return rec_multimap_array_container(pt.transpose, a, axes) def concatenate(self, arrays, axis=0): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.concatenate, arrays, axis) + return rec_multimap_array_container(pt.concatenate, arrays, axis) def ones_like(self, ary): def _ones_like(subary): @@ -87,18 +82,15 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def maximum(self, x, y): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.maximum, x, y) + return rec_multimap_array_container(pt.maximum, x, y) def minimum(self, x, y): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.minimum, x, y) + return rec_multimap_array_container(pt.minimum, x, y) def where(self, criterion, then, else_): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.where, criterion, then, else_) + return rec_multimap_array_container(pt.where, criterion, then, else_) def sum(self, a, dtype=None): import pytato as pt @@ -116,8 +108,7 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def stack(self, arrays, axis=0): import pytato as pt - from meshmode.dof_array import obj_or_dof_array_vectorize_n_args - return obj_or_dof_array_vectorize_n_args(pt.stack, arrays, axis) + return rec_multimap_array_container(pt.stack, arrays, axis) class PytatoCompiledOperator: diff --git a/requirements.txt b/requirements.txt index b1beba77d9370fa0d5833eb9ddc0e7a22f708e89..57b156775f32c204b586bfeaf4b7c5b33c96fc2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,3 @@ git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy@kernel_callables_v3-edit2.git#egg=loopy git+https://github.com/kaushikcfd/pytato.git@call_loopy#egg=pytato - -git+https://github.com/inducer/meshmode.git#egg=meshmode