From e334008d19775abf2fe981284aa6436447e0a409 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 18 Jun 2021 14:39:09 -0500 Subject: [PATCH] import pytato / pyopencl only once --- arraycontext/impl/pytato.py | 41 +++++++------------------------------ 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 8e8cbbc..773ee98 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -38,6 +38,13 @@ import loopy as lp from dataclasses import dataclass, field from pyrsistent import pmap, PMap +try: + import pyopencl as cl # noqa: F401 + import pyopencl.array as cla + import pytato as pt +except ImportError: + pass + class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): # Everything is implemented in the base class for now. @@ -63,97 +70,76 @@ class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", "sqrt", "exp"] if name in pt_funcs: - import pytato as pt # type: ignore from functools import partial return partial(rec_map_array_container, getattr(pt, name)) return super().__getattr__(name) def reshape(self, a, newshape): - import pytato as pt return rec_multimap_array_container(pt.reshape, a, newshape) def transpose(self, a, axes=None): - import pytato as pt return rec_multimap_array_container(pt.transpose, a, axes) def concatenate(self, arrays, axis=0): - import pytato as pt return rec_multimap_array_container(pt.concatenate, arrays, axis) def ones_like(self, ary): def _ones_like(subary): - import pytato as pt return pt.ones(subary.shape, subary.dtype) return self._new_like(ary, _ones_like) def maximum(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.maximum, x, y) def minimum(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.minimum, x, y) def where(self, criterion, then, else_): - import pytato as pt return rec_multimap_array_container(pt.where, criterion, then, else_) def sum(self, a, dtype=None): - import pytato as pt if dtype not in [a.dtype, None]: raise NotImplementedError return pt.sum(a) def min(self, a): - import pytato as pt return pt.amin(a) def max(self, a): - import pytato as pt return pt.amax(a) def stack(self, arrays, axis=0): - import pytato as pt return rec_multimap_array_container(pt.stack, arrays, axis) # {{{ relational operators def equal(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.equal, x, y) def not_equal(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.not_equal, x, y) def greater(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.greater, x, y) def greater_equal(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.greater_equal, x, y) def less(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.less, x, y) def less_equal(self, x, y): - import pytato as pt return rec_multimap_array_container(pt.less_equal, x, y) def conj(self, x): - import pytato as pt return rec_multimap_array_container(pt.conj, x) def arctan2(self, y, x): - import pytato as pt return rec_multimap_array_container(pt.arctan2, y, x) def ravel(self, a, order="C"): - import pytato as pt def _rec_ravel(a): if order in "FC": @@ -197,7 +183,6 @@ class PytatoCompiledOperator: from arraycontext.container.traversal import (rec_keyed_map_array_container, is_array_container) - import pytato as pt def to_arg_descr(arg): if np.isscalar(arg): @@ -293,8 +278,6 @@ class PytatoExecutable: self.output_template = output_template def __call__(self, *args): - import pytato as pt - import pyopencl.array as cla from arraycontext.container import is_array_container from arraycontext.container.traversal import rec_keyed_map_array_container @@ -389,12 +372,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): raise ValueError("PytatoPyOpenCLArrayContext does not support empty") def zeros(self, shape, dtype): - import pytato as pt return pt.zeros(shape, dtype) def from_numpy(self, np_array: np.ndarray): - import pytato as pt - import pyopencl.array as cla cl_array = cla.to_device(self.queue, np_array) return pt.make_data_wrapper(cl_array) @@ -404,7 +384,6 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def call_loopy(self, program, **kwargs): from pytato.loopy import call_loopy # type: ignore - import pyopencl.array as cla entrypoint, = set(program.callables_table) # thaw frozen arrays @@ -414,8 +393,6 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return call_loopy(program, kwargs, entrypoint) def freeze(self, array): - import pytato as pt - import pyopencl.array as cla if isinstance(array, pt.Placeholder): raise ValueError("freezing placeholder would return garbage valued" @@ -433,8 +410,6 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return cl_array.with_queue(None) def thaw(self, array): - import pytato as pt - import pyopencl.array as cla if not isinstance(array, cla.Array): raise TypeError("PytatoPyOpenCLArrayContext.thaw expects CL arrays, got " @@ -513,8 +488,6 @@ class PytatoPyOpenCLArrayContext(ArrayContext): warn("'arg_names' don't bear any significance in " "PytatoPyOpenCLArrayContext.", stacklevel=2) - import pytato as pt - import pyopencl.array as cla def preprocess_arg(arg): if isinstance(arg, cla.Array): -- GitLab