diff --git a/arraycontext/impl/__init__.py b/arraycontext/impl/__init__.py index ac0e47a3b24dcf4fc519c7120757fb1b4da29079..6df8258599acbee87f9b031e86ae365dbb532cbb 100644 --- a/arraycontext/impl/__init__.py +++ b/arraycontext/impl/__init__.py @@ -21,3 +21,12 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + + +def _is_meshmode_dofarray(x): + try: + from meshmode.dof_array import DOFArray + except ImportError: + return False + else: + return isinstance(x, DOFArray) diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index e50a58d4d393e9f9c8b765a3f1a0703e11077756..ec75e0fe12a6bda9342ea72df803817c7e82c782 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -168,25 +168,22 @@ class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): if ord is None: ord = 2 - try: - from meshmode.dof_array import DOFArray - except ImportError: - pass - else: - if isinstance(ary, DOFArray): - from warnings import warn - warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. " - "(DOFArrays use 2D arrays internally, and " - "actx.np.linalg.norm should compute matrix norms of those.) " - "This will stop working in 2022. " - "Use meshmode.dof_array.flat_norm instead.", - DeprecationWarning, stacklevel=2) - - import numpy.linalg as la - return la.norm( - [self.norm(_flatten_array(subary), ord=ord) - for _, subary in serialize_container(ary)], - ord=ord) + from arraycontext.impl import _is_meshmode_dofarray + + if _is_meshmode_dofarray(ary): + from warnings import warn + warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. " + "(DOFArrays use 2D arrays internally, and " + "actx.np.linalg.norm should compute matrix norms of those.) " + "This will stop working in 2022. " + "Use meshmode.dof_array.flat_norm instead.", + DeprecationWarning, stacklevel=2) + + import numpy.linalg as la + return la.norm( + [self.norm(_flatten_array(subary), ord=ord) + for _, subary in serialize_container(ary)], + ord=ord) return super().norm(ary, ord) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index da9001875832acd00ee116c307aedd106a972704..e3d6e6e75e51f86dd1f1d87f52a98ba0d68e82da 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -121,7 +121,7 @@ class PytatoCompiledOperator: def __call__(self, *args): import pytato as pt import pyopencl.array as cla - from meshmode.dof_array import DOFArray + from arraycontext.impl import _is_meshmode_dofarray from pytools.obj_array import flat_obj_array updated_kwargs = {} @@ -149,6 +149,7 @@ class PytatoCompiledOperator: return input_dict def from_return_dict_to_obj_array(return_dict): + from meshmode.dof_array import DOFArray return flat_obj_array([DOFArray.from_list(self.actx, [self.actx.thaw(return_dict[f"_msh_out_{i}_{j}"]) for j in range(self.output_spec[i])]) @@ -163,7 +164,7 @@ class PytatoCompiledOperator: updated_kwargs[arg_name] = cla.to_device(self.actx.queue, np.array(arg)) - elif isinstance(arg, np.ndarray) and all(isinstance(el, DOFArray) + elif isinstance(arg, np.ndarray) and all(_is_meshmode_dofarray(el) for el in arg): updated_kwargs.update(from_obj_array_to_input_dict(arg, iarg)) else: @@ -270,6 +271,7 @@ class PytatoArrayContext(ArrayContext): def compile(self, f: Callable[[Any], Any], inputs_like: Tuple[Union[Number, np.array], ...]) -> Callable[..., Any]: from pytools.obj_array import flat_obj_array + from arraycontext.impl import _is_meshmode_dofarray from meshmode.dof_array import DOFArray import pytato as pt @@ -277,7 +279,7 @@ class PytatoArrayContext(ArrayContext): if isinstance(input_like, np.number): return pt.make_placeholder(input_like.dtype, f"_msh_inp_{pos}") - elif isinstance(input_like, np.ndarray) and all(isinstance(e, DOFArray) + elif isinstance(input_like, np.ndarray) and all(_is_meshmode_dofarray(e) for e in input_like): return flat_obj_array([DOFArray.from_list(self, [pt.make_placeholder(grp_ary.shape, @@ -303,7 +305,7 @@ class PytatoArrayContext(ArrayContext): for iel, el in enumerate(inputs_like)]) if not (isinstance(outputs, np.ndarray) - and all(isinstance(e, DOFArray) + and all(_is_meshmode_dofarray(e) for e in outputs)): raise TypeError("Can only pass in functions that return numpy" " array of DOFArrays.")