From bcb020d0cdbf085209fc9a9f9dfdb966d4de27fa Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 25 Jul 2023 10:48:16 -0500
Subject: [PATCH] Make use of loopy.TranslationUnit.executor

This avoids long-lived references to CL kernels held by loopy caches.
---
 sumpy/e2e.py   | 16 ++++++++--------
 sumpy/e2p.py   |  6 +++---
 sumpy/p2e.py   |  2 +-
 sumpy/p2p.py   |  8 ++++----
 sumpy/qbx.py   |  6 +++---
 sumpy/tools.py | 15 ++++++++++++---
 6 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/sumpy/e2e.py b/sumpy/e2e.py
index 0a0726a0..9a732378 100644
--- a/sumpy/e2e.py
+++ b/sumpy/e2e.py
@@ -82,7 +82,7 @@ class E2EBase(KernelCacheMixin, ABC):
                     SourceTransformationRemover()(
                         TargetTransformationRemover()(tgt_expansion.kernel)))
 
-        self.ctx = ctx
+        self.context = ctx
         self.src_expansion = src_expansion
         self.tgt_expansion = tgt_expansion
         self.name = name or self.default_name
@@ -297,7 +297,7 @@ class E2EFromCSR(E2EBase):
         src_rscale = centers.dtype.type(kwargs.pop("src_rscale"))
         tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
 
-        knl = self.get_cached_optimized_kernel()
+        knl = self.get_cached_kernel_executor()
 
         return knl(queue,
                 centers=centers,
@@ -537,7 +537,7 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
         tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
         src_expansions = kwargs.pop("src_expansions")
 
-        knl = self.get_cached_optimized_kernel(result_dtype=src_expansions.dtype)
+        knl = self.get_cached_kernel_executor(result_dtype=src_expansions.dtype)
 
         return knl(queue,
                 src_expansions=src_expansions,
@@ -647,7 +647,7 @@ class M2LGenerateTranslationClassesDependentData(E2EBase):
                 "m2l_translation_classes_dependent_data")
         result_dtype = m2l_translation_classes_dependent_data.dtype
 
-        knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
+        knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
 
         return knl(queue,
                 src_rscale=src_rscale,
@@ -741,7 +741,7 @@ class M2LPreprocessMultipole(E2EBase):
         """
         preprocessed_src_expansions = kwargs.pop("preprocessed_src_expansions")
         result_dtype = preprocessed_src_expansions.dtype
-        knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
+        knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
 
         return knl(queue,
                 preprocessed_src_expansions=preprocessed_src_expansions, **kwargs)
@@ -840,7 +840,7 @@ class M2LPostprocessLocal(E2EBase):
         """
         tgt_expansions = kwargs.pop("tgt_expansions")
         result_dtype = tgt_expansions.dtype
-        knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
+        knl = self.get_cached_kernel_executor(result_dtype=result_dtype)
 
         return knl(queue, tgt_expansions=tgt_expansions, **kwargs)
 
@@ -950,7 +950,7 @@ class E2EFromChildren(E2EBase):
         :arg tgt_rscale:
         :arg centers:
         """
-        knl = self.get_cached_optimized_kernel()
+        knl = self.get_cached_kernel_executor()
 
         centers = kwargs.pop("centers")
         # "1" may be passed for rscale, which won't have its type
@@ -1054,7 +1054,7 @@ class E2EFromParent(E2EBase):
         :arg tgt_rscale:
         :arg centers:
         """
-        knl = self.get_cached_optimized_kernel()
+        knl = self.get_cached_kernel_executor()
 
         centers = kwargs.pop("centers")
         # "1" may be passed for rscale, which won't have its type
diff --git a/sumpy/e2p.py b/sumpy/e2p.py
index eb9038dc..0ec1c48b 100644
--- a/sumpy/e2p.py
+++ b/sumpy/e2p.py
@@ -68,7 +68,7 @@ class E2PBase(KernelCacheMixin, ABC):
         for knl in kernels:
             assert txr(knl) == expansion.kernel
 
-        self.ctx = ctx
+        self.context = ctx
         self.expansion = expansion
         self.kernels = kernels
         self.name = name or self.default_name
@@ -210,7 +210,7 @@ class E2PFromSingleBox(E2PBase):
         :arg centers:
         :arg targets:
         """
-        knl = self.get_cached_optimized_kernel()
+        knl = self.get_cached_kernel_executor()
 
         centers = kwargs.pop("centers")
         # "1" may be passed for rscale, which won't have its type
@@ -327,7 +327,7 @@ class E2PFromCSR(E2PBase):
         return knl
 
     def __call__(self, queue, **kwargs):
-        knl = self.get_cached_optimized_kernel()
+        knl = self.get_cached_kernel_executor()
 
         centers = kwargs.pop("centers")
         # "1" may be passed for rscale, which won't have its type
diff --git a/sumpy/p2e.py b/sumpy/p2e.py
index 7118bc2f..fe52f6b9 100644
--- a/sumpy/p2e.py
+++ b/sumpy/p2e.py
@@ -124,7 +124,7 @@ class P2EBase(KernelCacheMixin, KernelComputation):
         from sumpy.tools import is_obj_array_like
         sources = kwargs.pop("sources")
         centers = kwargs.pop("centers")
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 sources_is_obj_array=is_obj_array_like(sources),
                 centers_is_obj_array=is_obj_array_like(centers))
 
diff --git a/sumpy/p2p.py b/sumpy/p2p.py
index d03d8bd1..0b986c7d 100644
--- a/sumpy/p2p.py
+++ b/sumpy/p2p.py
@@ -256,7 +256,7 @@ class P2P(P2PBase):
         return loopy_knl
 
     def __call__(self, queue, targets, sources, strength, **kwargs):
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 targets_is_obj_array=is_obj_array_like(targets),
                 sources_is_obj_array=is_obj_array_like(sources))
 
@@ -318,7 +318,7 @@ class P2PMatrixGenerator(P2PBase):
         return loopy_knl
 
     def __call__(self, queue, targets, sources, **kwargs):
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 targets_is_obj_array=is_obj_array_like(targets),
                 sources_is_obj_array=is_obj_array_like(sources))
 
@@ -429,7 +429,7 @@ class P2PMatrixSubsetGenerator(P2PBase):
         :returns: a one-dimensional array of interactions, for each index pair
             in (*srcindices*, *tgtindices*)
         """
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 targets_is_obj_array=is_obj_array_like(targets),
                 sources_is_obj_array=is_obj_array_like(sources))
 
@@ -731,7 +731,7 @@ class P2PFromCSR(P2PBase):
         else:
             dtype_size = None
 
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 max_nsources_in_one_box=max_nsources_in_one_box,
                 max_ntargets_in_one_box=max_ntargets_in_one_box,
                 dtype_size=dtype_size,
diff --git a/sumpy/qbx.py b/sumpy/qbx.py
index bd9d8fd2..9c903ab1 100644
--- a/sumpy/qbx.py
+++ b/sumpy/qbx.py
@@ -288,7 +288,7 @@ class LayerPotential(LayerPotentialBase):
             already multiplied in.
         """
 
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 targets_is_obj_array=is_obj_array_like(targets),
                 sources_is_obj_array=is_obj_array_like(sources),
                 centers_is_obj_array=is_obj_array_like(centers))
@@ -359,7 +359,7 @@ class LayerPotentialMatrixGenerator(LayerPotentialBase):
         return loopy_knl
 
     def __call__(self, queue, targets, sources, centers, expansion_radii, **kwargs):
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 targets_is_obj_array=is_obj_array_like(targets),
                 sources_is_obj_array=is_obj_array_like(sources),
                 centers_is_obj_array=is_obj_array_like(centers))
@@ -479,7 +479,7 @@ class LayerPotentialMatrixSubsetGenerator(LayerPotentialBase):
             in (*srcindices*, *tgtindices*)
         """
 
-        knl = self.get_cached_optimized_kernel(
+        knl = self.get_cached_kernel_executor(
                 targets_is_obj_array=is_obj_array_like(targets),
                 sources_is_obj_array=is_obj_array_like(sources),
                 centers_is_obj_array=is_obj_array_like(centers))
diff --git a/sumpy/tools.py b/sumpy/tools.py
index 9dfc24a5..be404fe9 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -381,8 +381,17 @@ class OrderedSet(MutableSet):
 
 
 class KernelCacheMixin:
-    @memoize_method
     def get_cached_optimized_kernel(self, **kwargs):
+        from warnings import warn
+        warn("get_cached_optimized_kernel is deprecated. "
+             "Use get_cached_kernel_executor instead. "
+             "This will stop working in October 2023.",
+             DeprecationWarning, stacklevel=2)
+
+        return self.get_cached_kernel_executor(**kwargs)
+
+    @memoize_method
+    def get_cached_kernel_executor(self, **kwargs) -> lp.ExecutorBase:
         from sumpy import (code_cache, CACHING_ENABLED, OPT_ENABLED,
             NO_CACHE_KERNELS)
 
@@ -401,7 +410,7 @@ class KernelCacheMixin:
                 result = code_cache[cache_key]
                 logger.debug("{}: kernel cache hit [key={}]".format(
                     self.name, cache_key))
-                return result
+                return result.executor(self.context)
             except KeyError:
                 pass
 
@@ -422,7 +431,7 @@ class KernelCacheMixin:
                 NO_CACHE_KERNELS and self.name in NO_CACHE_KERNELS):
             code_cache.store_if_not_present(cache_key, knl)
 
-        return knl
+        return knl.executor(self.context)
 
     @staticmethod
     def _allow_redundant_execution_of_knl_scaling(knl):
-- 
GitLab