From 53e3857ed70f68150862859419162178486f1ffa Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Fri, 18 Jun 2021 09:46:50 -0500
Subject: [PATCH] import pyopencl.array once at the top

---
 arraycontext/impl/pyopencl.py | 43 ++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py
index 80f0bb4..0b96ed8 100644
--- a/arraycontext/impl/pyopencl.py
+++ b/arraycontext/impl/pyopencl.py
@@ -43,6 +43,12 @@ from arraycontext.container.traversal import (rec_multimap_array_container,
 from arraycontext.container import serialize_container, is_array_container
 from arraycontext.context import ArrayContext
 
+try:
+    import pyopencl as cl  # noqa: F401
+    import pyopencl.array as cl_array
+except ImportError:
+    pass
+
 
 # {{{ fake numpy
 
@@ -88,20 +94,16 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return self._new_like(ary, _ones_like)
 
     def maximum(self, x, y):
-        import pyopencl.array as cl_array
         return rec_multimap_array_container(
                 partial(cl_array.maximum, queue=self._array_context.queue),
                 x, y)
 
     def minimum(self, x, y):
-        import pyopencl.array as cl_array
         return rec_multimap_array_container(
                 partial(cl_array.minimum, queue=self._array_context.queue),
                 x, y)
 
     def where(self, criterion, then, else_):
-        import pyopencl.array as cl_array
-
         def where_inner(inner_crit, inner_then, inner_else):
             if isinstance(inner_crit, bool):
                 return inner_then if inner_crit else inner_else
@@ -111,32 +113,26 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_multimap_array_container(where_inner, criterion, then, else_)
 
     def sum(self, a, dtype=None):
-        import pyopencl.array as cl_array
         return cl_array.sum(
                 a, dtype=dtype, queue=self._array_context.queue).get()[()]
 
     def min(self, a):
-        import pyopencl.array as cl_array
         return cl_array.min(a, queue=self._array_context.queue).get()[()]
 
     def max(self, a):
-        import pyopencl.array as cl_array
         return cl_array.max(a, queue=self._array_context.queue).get()[()]
 
     def stack(self, arrays, axis=0):
-        import pyopencl.array as cla
         return rec_multimap_array_container(
-                lambda *args: cla.stack(arrays=args, axis=axis,
+                lambda *args: cl_array.stack(arrays=args, axis=axis,
                     queue=self._array_context.queue),
                 *arrays)
 
     def reshape(self, a, newshape):
-        import pyopencl.array as cla
-        return cla.reshape(a, newshape)
+        return cl_array.reshape(a, newshape)
 
     def concatenate(self, arrays, axis=0):
-        import pyopencl.array as cla
-        return cla.concatenate(
+        return cl_array.concatenate(
             arrays, axis,
             self._array_context.queue,
             self._array_context.allocator
@@ -170,8 +166,7 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
 # {{{ fake np.linalg
 
 def _flatten_array(ary):
-    import pyopencl.array as cl
-    assert isinstance(ary, cl.Array)
+    assert isinstance(ary, cl_array.Array)
 
     if ary.size == 0:
         # Work around https://github.com/inducer/pyopencl/pull/402
@@ -189,12 +184,11 @@ def _flatten_array(ary):
 class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
     def norm(self, ary, ord=None):
         from numbers import Number
-        import pyopencl.array as cla
 
         if isinstance(ary, Number):
             return abs(ary)
 
-        if ord is None and isinstance(ary, cla.Array):
+        if ord is None and isinstance(ary, cl_array.Array):
             if ary.ndim == 1:
                 ord = 2
             else:
@@ -291,6 +285,11 @@ class PyOpenCLArrayContext(ArrayContext):
             For now, *wait_event_queue_length* should be regarded as an
             experimental feature that may change or disappear at any minute.
         """
+        # pyopencl is already imported at the top of the file, importing once
+        # again to raise ImportError before the user instantiates a
+        # PyOpenCLArrayContext.
+        import pyopencl as cl  # noqa: F811
+
         super().__init__()
         self.context = queue.context
         self.queue = queue
@@ -302,7 +301,6 @@ class PyOpenCLArrayContext(ArrayContext):
         self._wait_event_queue_length = wait_event_queue_length
         self._kernel_name_to_wait_event_queue = {}
 
-        import pyopencl as cl
         if queue.device.type & cl.device_type.GPU:
             if allocator is None:
                 warn("PyOpenCLArrayContext created without an allocator on a GPU. "
@@ -324,18 +322,15 @@ class PyOpenCLArrayContext(ArrayContext):
     # {{{ ArrayContext interface
 
     def empty(self, shape, dtype):
-        import pyopencl.array as cla
-        return cla.empty(self.queue, shape=shape, dtype=dtype,
+        return cl_array.empty(self.queue, shape=shape, dtype=dtype,
                 allocator=self.allocator)
 
     def zeros(self, shape, dtype):
-        import pyopencl.array as cla
-        return cla.zeros(self.queue, shape=shape, dtype=dtype,
+        return cl_array.zeros(self.queue, shape=shape, dtype=dtype,
                 allocator=self.allocator)
 
     def from_numpy(self, array: np.ndarray):
-        import pyopencl.array as cla
-        return cla.to_device(self.queue, array, allocator=self.allocator)
+        return cl_array.to_device(self.queue, array, allocator=self.allocator)
 
     def to_numpy(self, array):
         return array.get(queue=self.queue)
-- 
GitLab