From 8db8f6f97e9158aa4f62a7bf1049f6aac02fd738 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 2 Aug 2023 20:55:53 +0300
Subject: [PATCH] tools: reorder non-top level imports

---
 sumpy/tools.py | 66 +++++++++++++++++++++++++++++---------------------
 1 file changed, 39 insertions(+), 27 deletions(-)

diff --git a/sumpy/tools.py b/sumpy/tools.py
index 25addd18..0a8a8557 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -205,9 +205,8 @@ def build_matrix(op, dtype=None, shape=None):
 
 
 def vector_to_device(queue, vec):
-    from pytools.obj_array import obj_array_vectorize
-
     from pyopencl.array import to_device
+    from pytools.obj_array import obj_array_vectorize
 
     def to_dev(ary):
         return to_device(queue, ary)
@@ -449,12 +448,12 @@ class KernelCacheMixin:
 
     @memoize_method
     def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
-        from sumpy import (code_cache, CACHING_ENABLED, OPT_ENABLED,
-            NO_CACHE_KERNELS)
+        from sumpy import CACHING_ENABLED, NO_CACHE_KERNELS, OPT_ENABLED, code_cache
 
         if CACHING_ENABLED and not (
                 NO_CACHE_KERNELS and self.name in NO_CACHE_KERNELS):
             import loopy.version
+
             from sumpy.version import KERNEL_VERSION
             cache_key = (
                     self.get_cache_key()
@@ -465,8 +464,7 @@ class KernelCacheMixin:
 
             try:
                 result = code_cache[cache_key]
-                logger.debug("{}: kernel cache hit [key={}]".format(
-                    self.name, cache_key))
+                logger.debug("%s: kernel cache hit [key=%s]", self.name, cache_key)
                 return result.executor(self.context)
             except KeyError:
                 pass
@@ -678,7 +676,8 @@ class ProfileGetter:
 
 
 def get_native_event(evt):
-    return evt if isinstance(evt, cl.Event) else evt.native_event
+    from pyopencl import Event
+    return evt if isinstance(evt, Event) else evt.native_event
 
 
 class AggregateProfilingEvent:
@@ -719,9 +718,11 @@ class MarkerBasedProfilingEvent:
 
 def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
         name=None):
-    from pymbolic.algorithm import find_factors
     from math import pi
 
+    from pymbolic import var
+    from pymbolic.algorithm import find_factors
+
     sign = 1 if not inverse else -1
     n = shape[-1]
 
@@ -733,7 +734,7 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
 
     nfft = n
 
-    broadcast_dims = tuple(pymbolic.var(f"j{d}") for d in range(len(shape) - 1))
+    broadcast_dims = tuple(var(f"j{d}") for d in range(len(shape) - 1))
 
     domains = [
         "{[i]: 0<=i<n}",
@@ -741,11 +742,11 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
     ]
     domains += [f"{{[j{d}]: 0<=j{d}<{shape[d]} }}" for d in range(len(shape) - 1)]
 
-    x = pymbolic.var("x")
-    y = pymbolic.var("y")
-    i = pymbolic.var("i")
-    i2 = pymbolic.var("i2")
-    i3 = pymbolic.var("i3")
+    x = var("x")
+    y = var("y")
+    i = var("i")
+    i2 = var("i2")
+    i3 = var("i3")
 
     fixed_parameters = {"const": complex_dtype(sign*(-2j)*pi/n), "n": n}
 
@@ -767,16 +768,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
         else:
             init_depends_on = f"update_{ilev-1}"
 
-        temp = pymbolic.var("temp")
-        exp_table = pymbolic.var("exp_table")
-        i = pymbolic.var(f"i_{ilev}")
-        i2 = pymbolic.var(f"i2_{ilev}")
-        ifft = pymbolic.var(f"ifft_{ilev}")
-        iN1 = pymbolic.var(f"iN1_{ilev}")           # noqa: N806
-        iN1_sum = pymbolic.var(f"iN1_sum_{ilev}")   # noqa: N806
-        iN2 = pymbolic.var(f"iN2_{ilev}")           # noqa: N806
-        table_idx = pymbolic.var(f"table_idx_{ilev}")
-        exp = pymbolic.var(f"exp_{ilev}")
+        temp = var("temp")
+        exp_table = var("exp_table")
+        i = var(f"i_{ilev}")
+        i2 = var(f"i2_{ilev}")
+        ifft = var(f"ifft_{ilev}")
+        iN1 = var(f"iN1_{ilev}")           # noqa: N806
+        iN1_sum = var(f"iN1_sum_{ilev}")   # noqa: N806
+        iN2 = var(f"iN2_{ilev}")           # noqa: N806
+        table_idx = var(f"table_idx_{ilev}")
+        exp = var(f"exp_{ilev}")
 
         insns += [
             lp.Assignment(
@@ -879,12 +880,16 @@ def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
 
 
 class FFTBackend(enum.Enum):
+    #: FFT backend based on the vkFFT library.
     pyvkfft = 1
+    #: FFT backend based on :mod:`loopy` used as a fallback.
     loopy = 2
 
 
-def _get_fft_backend(queue) -> FFTBackend:
-    env_val = os.environ.get("SUMPY_FFT_BACKEND", None)
+def _get_fft_backend(queue: "cl.CommandQueue") -> FFTBackend:
+    import os
+
+    env_val = os.environ.get("SUMPY_FFT_BACKEND")
     if env_val:
         if env_val not in ["loopy", "pyvkfft"]:
             raise ValueError("Expected 'loopy' or 'pyvkfft' for SUMPY_FFT_BACKEND. "
@@ -897,13 +902,17 @@ def _get_fft_backend(queue) -> FFTBackend:
         warnings.warn("VkFFT not found. FFT runs will be slower.", stacklevel=3)
         return FFTBackend.loopy
 
-    if queue.properties & cl.command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
+    from pyopencl import command_queue_properties
+
+    if queue.properties & command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
         warnings.warn(
             "VkFFT does not support out of order queues yet. "
             "Falling back to slower implementation.", stacklevel=3)
         return FFTBackend.loopy
 
     import platform
+    import sys
+
     if (sys.platform == "darwin"
             and platform.machine() == "x86_64"
             and queue.context.devices[0].platform.name
@@ -960,6 +969,9 @@ def run_opencl_fft(
         if wait_for is None:
             wait_for = []
 
+        import pyopencl as cl
+        import pyopencl.array as cla
+
         start_evt = cl.enqueue_marker(queue, wait_for=wait_for[:])
 
         if app.inplace:
-- 
GitLab