From ef786c92b715e64a06e74f9b97e83b71700b3b85 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Wed, 2 Jun 2021 15:06:24 -0500
Subject: [PATCH] _is_meshmode_dofarray

---
 arraycontext/impl/__init__.py |  9 +++++++++
 arraycontext/impl/pyopencl.py | 35 ++++++++++++++++-------------------
 arraycontext/impl/pytato.py   | 10 ++++++----
 3 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/arraycontext/impl/__init__.py b/arraycontext/impl/__init__.py
index ac0e47a..6df8258 100644
--- a/arraycontext/impl/__init__.py
+++ b/arraycontext/impl/__init__.py
@@ -21,3 +21,12 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
+
+
+def _is_meshmode_dofarray(x):
+    try:
+        from meshmode.dof_array import DOFArray
+    except ImportError:
+        return False
+    else:
+        return isinstance(x, DOFArray)
diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py
index e50a58d..ec75e0f 100644
--- a/arraycontext/impl/pyopencl.py
+++ b/arraycontext/impl/pyopencl.py
@@ -168,25 +168,22 @@ class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
         if ord is None:
             ord = 2
 
-        try:
-            from meshmode.dof_array import DOFArray
-        except ImportError:
-            pass
-        else:
-            if isinstance(ary, DOFArray):
-                from warnings import warn
-                warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
-                        "(DOFArrays use 2D arrays internally, and "
-                        "actx.np.linalg.norm should compute matrix norms of those.) "
-                        "This will stop working in 2022. "
-                        "Use meshmode.dof_array.flat_norm instead.",
-                        DeprecationWarning, stacklevel=2)
-
-                import numpy.linalg as la
-                return la.norm(
-                        [self.norm(_flatten_array(subary), ord=ord)
-                            for _, subary in serialize_container(ary)],
-                        ord=ord)
+        from arraycontext.impl import _is_meshmode_dofarray
+
+        if _is_meshmode_dofarray(ary):
+            from warnings import warn
+            warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
+                    "(DOFArrays use 2D arrays internally, and "
+                    "actx.np.linalg.norm should compute matrix norms of those.) "
+                    "This will stop working in 2022. "
+                    "Use meshmode.dof_array.flat_norm instead.",
+                    DeprecationWarning, stacklevel=2)
+
+            import numpy.linalg as la
+            return la.norm(
+                    [self.norm(_flatten_array(subary), ord=ord)
+                        for _, subary in serialize_container(ary)],
+                    ord=ord)
 
         return super().norm(ary, ord)
 
diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py
index da90018..e3d6e6e 100644
--- a/arraycontext/impl/pytato.py
+++ b/arraycontext/impl/pytato.py
@@ -121,7 +121,7 @@ class PytatoCompiledOperator:
     def __call__(self, *args):
         import pytato as pt
         import pyopencl.array as cla
-        from meshmode.dof_array import DOFArray
+        from arraycontext.impl import _is_meshmode_dofarray
         from pytools.obj_array import flat_obj_array
 
         updated_kwargs = {}
@@ -149,6 +149,7 @@ class PytatoCompiledOperator:
             return input_dict
 
         def from_return_dict_to_obj_array(return_dict):
+            from meshmode.dof_array import DOFArray
             return flat_obj_array([DOFArray.from_list(self.actx,
                 [self.actx.thaw(return_dict[f"_msh_out_{i}_{j}"])
                  for j in range(self.output_spec[i])])
@@ -163,7 +164,7 @@ class PytatoCompiledOperator:
 
                 updated_kwargs[arg_name] = cla.to_device(self.actx.queue,
                         np.array(arg))
-            elif isinstance(arg, np.ndarray) and all(isinstance(el, DOFArray)
+            elif isinstance(arg, np.ndarray) and all(_is_meshmode_dofarray(el)
                                                      for el in arg):
                 updated_kwargs.update(from_obj_array_to_input_dict(arg, iarg))
             else:
@@ -270,6 +271,7 @@ class PytatoArrayContext(ArrayContext):
     def compile(self, f: Callable[[Any], Any],
             inputs_like: Tuple[Union[Number, np.array], ...]) -> Callable[..., Any]:
         from pytools.obj_array import flat_obj_array
+        from arraycontext.impl import _is_meshmode_dofarray
         from meshmode.dof_array import DOFArray
         import pytato as pt
 
@@ -277,7 +279,7 @@ class PytatoArrayContext(ArrayContext):
             if isinstance(input_like, np.number):
                 return pt.make_placeholder(input_like.dtype,
                                            f"_msh_inp_{pos}")
-            elif isinstance(input_like, np.ndarray) and all(isinstance(e, DOFArray)
+            elif isinstance(input_like, np.ndarray) and all(_is_meshmode_dofarray(e)
                                                             for e in input_like):
                 return flat_obj_array([DOFArray.from_list(self,
                     [pt.make_placeholder(grp_ary.shape,
@@ -303,7 +305,7 @@ class PytatoArrayContext(ArrayContext):
                       for iel, el in enumerate(inputs_like)])
 
         if not (isinstance(outputs, np.ndarray)
-                and all(isinstance(e, DOFArray)
+                and all(_is_meshmode_dofarray(e)
                         for e in outputs)):
             raise TypeError("Can only pass in functions that return numpy"
                             " array of DOFArrays.")
-- 
GitLab