From 528f78f67b9421d7e9b149db4678ce8812f6859f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 2 Jun 2017 18:46:31 -0400
Subject: [PATCH] Cache binary code for kernel invokers to disk

---
 pyopencl/cffi_cl.py |   4 +-
 pyopencl/invoker.py | 112 ++++++++++++++++++++++++++++++++------------
 setup.py            |   2 +-
 3 files changed, 85 insertions(+), 33 deletions(-)

diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py
index e8b5b289..2d76f916 100644
--- a/pyopencl/cffi_cl.py
+++ b/pyopencl/cffi_cl.py
@@ -1687,7 +1687,7 @@ class Kernel(_Common):
         return self
 
     def set_scalar_arg_dtypes(self, scalar_arg_dtypes):
-        self._scalar_arg_dtypes = scalar_arg_dtypes
+        self._scalar_arg_dtypes = tuple(scalar_arg_dtypes)
 
         # {{{ arg counting bug handling
 
@@ -1717,7 +1717,7 @@ class Kernel(_Common):
         self._enqueue, self._set_args = generate_enqueue_and_set_args(
                 self.function_name,
                 len(scalar_arg_dtypes), self.num_args,
-                scalar_arg_dtypes,
+                self._scalar_arg_dtypes,
                 warn_about_arg_count_bug=warn_about_arg_count_bug,
                 work_around_arg_count_bug=work_around_arg_count_bug)
 
diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py
index abb1b462..2e79efc9 100644
--- a/pyopencl/invoker.py
+++ b/pyopencl/invoker.py
@@ -29,6 +29,9 @@ import numpy as np
 
 from warnings import warn
 from pyopencl._cffi import ffi as _ffi
+from pytools.persistent_dict import (
+        PersistentDict,
+        KeyBuilder as KeyBuilderBase)
 
 _PYPY = '__pypy__' in sys.builtin_module_names
 _CPY2 = not _PYPY and sys.version_info < (3,)
@@ -285,24 +288,21 @@ def wrap_in_error_handler(body, arg_names):
 # }}}
 
 
-def add_module_preamble(gen):
-    gen.add_to_preamble(
-        "import numpy as np")
-    gen.add_to_preamble(
-        "import pyopencl.cffi_cl as _cl")
-    gen.add_to_preamble(
+def add_local_imports(gen):
+    gen("import numpy as np")
+    gen("import pyopencl.cffi_cl as _cl")
+    gen(
         "from pyopencl.cffi_cl import _lib, "
         "_ffi, _handle_error, _CLKernelArg")
-    gen.add_to_preamble("from pyopencl import status_code")
-    gen.add_to_preamble("from struct import pack")
-    gen.add_to_preamble("")
+    gen("")
 
 
-def generate_enqueue_and_set_args(function_name,
+def _generate_enqueue_and_set_args_module(function_name,
         num_passed_args, num_cl_args,
         scalar_arg_dtypes,
-        work_around_arg_count_bug, warn_about_arg_count_bug,):
-    from pytools.py_codegen import PythonFunctionGenerator
+        work_around_arg_count_bug, warn_about_arg_count_bug):
+
+    from pytools.py_codegen import PythonCodeGenerator, Indentation
 
     arg_names = ["arg%d" % i for i in range(num_passed_args)]
 
@@ -316,37 +316,89 @@ def generate_enqueue_and_set_args(function_name,
 
     err_handler = wrap_in_error_handler(body, arg_names)
 
-    # {{{ generate _enqueue
+    gen = PythonCodeGenerator()
 
-    gen = PythonFunctionGenerator("enqueue_knl_%s" % function_name,
-            ["self", "queue", "global_size", "local_size"]
-            + arg_names
-            + ["global_offset=None", "g_times_l=None", "wait_for=None"])
+    gen("from struct import pack")
+    gen("from pyopencl import status_code")
+    gen("")
 
-    add_module_preamble(gen)
-    gen.extend(err_handler)
+    # {{{ generate _enqueue
 
-    gen("""
-        return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
-                global_offset, wait_for, g_times_l=g_times_l)
-        """)
+    enqueue_name = "enqueue_knl_%s" % function_name
+    gen("def %s(%s):"
+            % (enqueue_name,
+                ", ".join(
+                    ["self", "queue", "global_size", "local_size"]
+                    + arg_names
+                    + ["global_offset=None", "g_times_l=None",
+                        "wait_for=None"])))
 
-    enqueue = gen.get_function()
+    with Indentation(gen):
+        add_local_imports(gen)
+        gen.extend(err_handler)
+
+        gen("""
+            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
+                    global_offset, wait_for, g_times_l=g_times_l)
+            """)
 
     # }}}
 
     # {{{ generate set_args
 
-    gen = PythonFunctionGenerator("_set_args", ["self"] + arg_names)
+    gen("")
+    gen("def set_args(%s):"
+            % (", ".join(["self"] + arg_names)))
 
-    add_module_preamble(gen)
-    gen.extend(err_handler)
-
-    set_args = gen.get_function()
+    with Indentation(gen):
+        add_local_imports(gen)
+        gen.extend(err_handler)
 
     # }}}
 
-    return enqueue, set_args
+    return gen.get_picklable_module(), enqueue_name
+
+
+class NumpyTypesKeyBuilder(KeyBuilderBase):
+    """Key builder that keys ``np.generic`` subclasses on their ``__name__``."""
+    def update_for_type(self, key_hash, key):
+        # *key* is itself a type here, so report it rather than type(key).
+        if not issubclass(key, np.generic):
+            raise TypeError(
+                "unsupported type for persistent hash keying: %s" % (key,))
+        self.update_for_str(key_hash, key.__name__)
+
+
+invoker_cache = PersistentDict("pyopencl-invoker-cache-v1",
+        key_builder=NumpyTypesKeyBuilder())
+
+
+def generate_enqueue_and_set_args(function_name,
+        num_passed_args, num_cl_args,
+        scalar_arg_dtypes,
+        work_around_arg_count_bug, warn_about_arg_count_bug):
+    """Return the ``(enqueue, set_args)`` invokers, using ``invoker_cache``.
+
+    On a miss, the module is generated and stored keyed on all arguments.
+    """
+    key = (function_name, num_passed_args, num_cl_args,
+            scalar_arg_dtypes,
+            work_around_arg_count_bug, warn_about_arg_count_bug)
+
+    try:
+        cached = invoker_cache[key]
+    except KeyError:
+        cache_hit = False
+    else:
+        cache_hit = True
+
+    if not cache_hit:
+        cached = _generate_enqueue_and_set_args_module(*key)
+        invoker_cache[key] = cached
+
+    invoker_mod, enqueue_fn_name = cached
+    return (invoker_mod.mod_globals[enqueue_fn_name],
+            invoker_mod.mod_globals["set_args"])
 
 # }}}
 
diff --git a/setup.py b/setup.py
index 3d2ddd0e..cdd356ae 100644
--- a/setup.py
+++ b/setup.py
@@ -219,7 +219,7 @@ def main():
 
             install_requires=[
                 "numpy",
-                "pytools>=2015.1.2",
+                "pytools>=2017.2",
                 "pytest>=2",
                 "decorator>=3.2.0",
                 "cffi>=1.1.0",
-- 
GitLab