From 53e3857ed70f68150862859419162178486f1ffa Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Fri, 18 Jun 2021 09:46:50 -0500 Subject: [PATCH] import pyopencl.array once at the top --- arraycontext/impl/pyopencl.py | 43 ++++++++++++++++------------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 80f0bb4..0b96ed8 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -43,6 +43,12 @@ from arraycontext.container.traversal import (rec_multimap_array_container, from arraycontext.container import serialize_container, is_array_container from arraycontext.context import ArrayContext +try: + import pyopencl as cl # noqa: F401 + import pyopencl.array as cl_array +except ImportError: + pass + # {{{ fake numpy @@ -88,20 +94,16 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return self._new_like(ary, _ones_like) def maximum(self, x, y): - import pyopencl.array as cl_array return rec_multimap_array_container( partial(cl_array.maximum, queue=self._array_context.queue), x, y) def minimum(self, x, y): - import pyopencl.array as cl_array return rec_multimap_array_container( partial(cl_array.minimum, queue=self._array_context.queue), x, y) def where(self, criterion, then, else_): - import pyopencl.array as cl_array - def where_inner(inner_crit, inner_then, inner_else): if isinstance(inner_crit, bool): return inner_then if inner_crit else inner_else @@ -111,32 +113,26 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_multimap_array_container(where_inner, criterion, then, else_) def sum(self, a, dtype=None): - import pyopencl.array as cl_array return cl_array.sum( a, dtype=dtype, queue=self._array_context.queue).get()[()] def min(self, a): - import pyopencl.array as cl_array return cl_array.min(a, queue=self._array_context.queue).get()[()] def max(self, a): - import pyopencl.array as cl_array return cl_array.max(a, queue=self._array_context.queue).get()[()] def stack(self, arrays, axis=0): - import pyopencl.array as cla return rec_multimap_array_container( - lambda *args: cla.stack(arrays=args, axis=axis, + lambda *args: cl_array.stack(arrays=args, axis=axis, queue=self._array_context.queue), *arrays) def reshape(self, a, newshape): - import pyopencl.array as cla - return cla.reshape(a, newshape) + return cl_array.reshape(a, newshape) def concatenate(self, arrays, axis=0): - import pyopencl.array as cla - return cla.concatenate( + return cl_array.concatenate( arrays, axis, self._array_context.queue, self._array_context.allocator @@ -170,8 +166,7 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): # {{{ fake np.linalg def _flatten_array(ary): - import pyopencl.array as cl - assert isinstance(ary, cl.Array) + assert isinstance(ary, cl_array.Array) if ary.size == 0: # Work around https://github.com/inducer/pyopencl/pull/402 @@ -189,12 +184,11 @@ def _flatten_array(ary): class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): def norm(self, ary, ord=None): from numbers import Number - import pyopencl.array as cla if isinstance(ary, Number): return abs(ary) - if ord is None and isinstance(ary, cla.Array): + if ord is None and isinstance(ary, cl_array.Array): if ary.ndim == 1: ord = 2 else: @@ -291,6 +285,11 @@ class PyOpenCLArrayContext(ArrayContext): For now, *wait_event_queue_length* should be regarded as an experimental feature that may change or disappear at any minute. """ + # pyopencl is already imported at the top of the file, importing once + # again to raise ImportError before the user instantiates a + # PyOpenCLArrayContext. + import pyopencl as cl # noqa: F811 + super().__init__() self.context = queue.context self.queue = queue @@ -302,7 +301,6 @@ class PyOpenCLArrayContext(ArrayContext): self._wait_event_queue_length = wait_event_queue_length self._kernel_name_to_wait_event_queue = {} - import pyopencl as cl if queue.device.type & cl.device_type.GPU: if allocator is None: warn("PyOpenCLArrayContext created without an allocator on a GPU. " @@ -324,18 +322,15 @@ class PyOpenCLArrayContext(ArrayContext): # {{{ ArrayContext interface def empty(self, shape, dtype): - import pyopencl.array as cla - return cla.empty(self.queue, shape=shape, dtype=dtype, + return cl_array.empty(self.queue, shape=shape, dtype=dtype, allocator=self.allocator) def zeros(self, shape, dtype): - import pyopencl.array as cla - return cla.zeros(self.queue, shape=shape, dtype=dtype, + return cl_array.zeros(self.queue, shape=shape, dtype=dtype, allocator=self.allocator) def from_numpy(self, array: np.ndarray): - import pyopencl.array as cla - return cla.to_device(self.queue, array, allocator=self.allocator) + return cl_array.to_device(self.queue, array, allocator=self.allocator) def to_numpy(self, array): return array.get(queue=self.queue) -- GitLab