diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py
index 2e79efc9b0906318709de14c07cec68c06198e1a..a11c6732e438e0a0021b218b8ab9888e3f62bfde 100644
--- a/pyopencl/invoker.py
+++ b/pyopencl/invoker.py
@@ -29,9 +29,8 @@ import numpy as np
 
 from warnings import warn
 from pyopencl._cffi import ffi as _ffi
-from pytools.persistent_dict import (
-        PersistentDict,
-        KeyBuilder as KeyBuilderBase)
+from pytools.persistent_dict import PersistentDict
+from pyopencl.tools import _NumpyTypesKeyBuilder
 
 _PYPY = '__pypy__' in sys.builtin_module_names
 _CPY2 = not _PYPY and sys.version_info < (3,)
@@ -359,18 +358,8 @@ def _generate_enqueue_and_set_args_module(function_name,
     return gen.get_picklable_module(), enqueue_name
 
 
-class NumpyTypesKeyBuilder(KeyBuilderBase):
-    def update_for_type(self, key_hash, key):
-        if issubclass(key, np.generic):
-            self.update_for_str(key_hash, key.__name__)
-            return
-
-        raise TypeError("unsupported type for persistent hash keying: %s"
-                % type(key))
-
-
 invoker_cache = PersistentDict("pyopencl-invoker-cache-v1",
-        key_builder=NumpyTypesKeyBuilder())
+        key_builder=_NumpyTypesKeyBuilder())
 
 
 def generate_enqueue_and_set_args(function_name,
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index f9166abee51000a984ed3a34c01ac08c13cf42dd..169939cab356e0f1ba7be00e7fa82e19e22979a5 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -34,11 +34,18 @@ import pyopencl.array  # noqa
 from pyopencl.tools import (dtype_to_ctype, bitlog2,
         KernelTemplateBase, _process_code_for_macro,
         get_arg_list_scalar_arg_dtypes,
-        context_dependent_memoize)
+        context_dependent_memoize,
+        _NumpyTypesKeyBuilder)
 
 import pyopencl._mymako as mako
 from pyopencl._cluda import CLUDA_PREAMBLE
 
+from pytools.persistent_dict import PersistentDict
+
+
+import logging
+logger = logging.getLogger(__name__)
+
 
 # {{{ preamble
 
@@ -854,11 +861,91 @@ def _make_template(s):
     return mako.template.Template(s, strict_undefined=True)
 
 
-from pytools import Record
+from pytools import Record, RecordWithoutPickling
 
 
-class _ScanKernelInfo(Record):
-    pass
+class _GeneratedScanKernelInfo(Record):
+    """Source and launch metadata of a scan kernel that is not yet built.
+
+    :meth:`build` compiles it into a :class:`_BuiltScanKernelInfo`.
+    """
+
+    __slots__ = [
+            "scan_src",
+            "kernel_name",
+            "scalar_arg_dtypes",
+            "wg_size",
+            "k_group_size"]
+
+    def __init__(self, scan_src, kernel_name, scalar_arg_dtypes, wg_size,
+            k_group_size):
+        Record.__init__(self,
+                scan_src=scan_src,
+                kernel_name=kernel_name,
+                scalar_arg_dtypes=scalar_arg_dtypes,
+                wg_size=wg_size,
+                k_group_size=k_group_size)
+
+    def build(self, context, options):
+        """Build the source in *context* and return the built-kernel record."""
+        program = cl.Program(context, self.scan_src).build(options)
+        kernel = getattr(program, self.kernel_name)
+        kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
+        return _BuiltScanKernelInfo(
+                kernel=kernel,
+                wg_size=self.wg_size,
+                k_group_size=self.k_group_size)
+
+
+class _BuiltScanKernelInfo(RecordWithoutPickling):
+    """A built scan kernel together with its ``wg_size``/``k_group_size``."""
+
+    __slots__ = ["kernel", "wg_size", "k_group_size"]
+
+    def __init__(self, kernel, wg_size, k_group_size):
+        RecordWithoutPickling.__init__(self,
+                kernel=kernel,
+                wg_size=wg_size,
+                k_group_size=k_group_size)
+
+
+class _GeneratedFinalUpdateKernelInfo(Record):
+    """Source and launch metadata of a final update kernel, not yet built.
+
+    :meth:`build` compiles it into a :class:`_BuiltFinalUpdateKernelInfo`.
+    """
+
+    # Declared for consistency with the other kernel info records.
+    __slots__ = [
+            "source",
+            "kernel_name",
+            "scalar_arg_dtypes",
+            "update_wg_size"]
+
+    def __init__(self, source, kernel_name, scalar_arg_dtypes, update_wg_size):
+        Record.__init__(self,
+                source=source,
+                kernel_name=kernel_name,
+                scalar_arg_dtypes=scalar_arg_dtypes,
+                update_wg_size=update_wg_size)
+
+    def build(self, context, options):
+        """Build the source in *context* and return the built-kernel record."""
+        program = cl.Program(context, self.source).build(options)
+        kernel = getattr(program, self.kernel_name)
+        kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
+        return _BuiltFinalUpdateKernelInfo(kernel, self.update_wg_size)
+
+
+class _BuiltFinalUpdateKernelInfo(RecordWithoutPickling):
+    """A built final update kernel together with its ``update_wg_size``."""
+
+    __slots__ = ["kernel", "update_wg_size"]
+
+    def __init__(self, kernel, update_wg_size):
+        RecordWithoutPickling.__init__(self,
+                kernel=kernel,
+                update_wg_size=update_wg_size)
 
 # }}}
 
@@ -1031,13 +1098,45 @@ class _GenericScanKernelBase(object):
                 has_double_support(dev) for dev in devices),
             )
 
+        index_typename = dtype_to_ctype(self.index_dtype)
+        scan_typename = dtype_to_ctype(dtype)
+
+        # This key is meant to uniquely identify the non-device parameters for
+        # the scan kernel.
+        self.kernel_key = (
+            self.dtype,
+            tuple(arg.declarator() for arg in self.parsed_args),
+            self.input_expr,
+            scan_expr,
+            neutral,
+            output_statement,
+            is_segment_start_expr,
+            tuple(input_fetch_exprs),
+            index_dtype,
+            name_prefix,
+            preamble,
+            # These come from dtype_to_ctype(), so their values are not
+            # determined by the other entries of this key.
+            index_typename,
+            scan_typename,
+            )
+
         # }}}
 
+        self.use_lookbehind_update = "prev_item" in self.output_statement
+        self.store_segment_start_flags = (
+                self.is_segmented and self.use_lookbehind_update)
+
         self.finish_setup()
 
     # }}}
 
 
+generic_scan_kernel_cache = PersistentDict(
+        "pyopencl-generated-scan-kernel-cache-v1",
+        key_builder=_NumpyTypesKeyBuilder())
+
+
 class GenericScanKernel(_GenericScanKernelBase):
     """Generates and executes code that performs prefix sums ("scans") on
     arbitrary types, with many possible tweaks.
@@ -1058,9 +1157,60 @@ class GenericScanKernel(_GenericScanKernelBase):
     """
 
     def finish_setup(self):
-        use_lookbehind_update = "prev_item" in self.output_statement
-        self.store_segment_start_flags = self.is_segmented and use_lookbehind_update
-
+        """Fetch or generate the kernel sources, then build all three kernels.
+
+        Only the generation results go through the persistent cache; the
+        kernels are built for ``self.context`` on every call.
+        """
+        # Before generating the kernel, see if it's cached.
+        from pyopencl.cache import get_device_cache_id
+        devices_key = tuple(get_device_cache_id(device)
+                for device in self.devices)
+
+        # NOTE(review): self.options is not part of the key although the
+        # cached work group sizes were found by building with it -- confirm
+        # that differing build options cannot invalidate a cached entry.
+        cache_key = (self.kernel_key, devices_key)
+
+        from_cache = False
+
+        try:
+            result = generic_scan_kernel_cache[cache_key]
+            from_cache = True
+            logger.debug(
+                    "cache hit for generated scan kernel '%s'", self.name_prefix)
+            (
+                self.first_level_scan_gen_info,
+                self.second_level_scan_gen_info,
+                self.final_update_gen_info) = result
+        except KeyError:
+            pass
+
+        if not from_cache:
+            logger.debug(
+                    "cache miss for generated scan kernel '%s'", self.name_prefix)
+            self._finish_setup_impl()
+
+            result = (self.first_level_scan_gen_info,
+                      self.second_level_scan_gen_info,
+                      self.final_update_gen_info)
+
+            generic_scan_kernel_cache[cache_key] = result
+
+        # Build the kernels.
+        self.first_level_scan_info = self.first_level_scan_gen_info.build(
+                self.context, self.options)
+        del self.first_level_scan_gen_info
+
+        self.second_level_scan_info = self.second_level_scan_gen_info.build(
+                self.context, self.options)
+        del self.second_level_scan_gen_info
+
+        self.final_update_info = self.final_update_gen_info.build(
+                self.context, self.options)
+        del self.final_update_gen_info
+
+    def _finish_setup_impl(self):
         # {{{ find usable workgroup/k-group size, build first-level scan
 
         trip_count = 0
@@ -1118,7 +1260,7 @@ class GenericScanKernel(_GenericScanKernelBase):
         _, k_group_size, max_scan_wg_size = max(solutions)
 
         while True:
-            candidate_scan_info = self.build_scan_kernel(
+            candidate_scan_gen_info = self.generate_scan_kernel(
                     max_scan_wg_size, self.parsed_args,
                     _process_code_for_macro(self.input_expr),
                     self.is_segment_start_expr,
@@ -1128,6 +1270,9 @@ class GenericScanKernel(_GenericScanKernelBase):
                     k_group_size=k_group_size,
                     use_bank_conflict_avoidance=use_bank_conflict_avoidance)
 
+            candidate_scan_info = candidate_scan_gen_info.build(
+                    self.context, self.options)
+
             # Will this device actually let us execute this kernel
             # at the desired work group size? Building it is the
             # only way to find out.
@@ -1145,7 +1290,7 @@ class GenericScanKernel(_GenericScanKernelBase):
             trip_count += 1
             assert trip_count <= 20
 
-        self.first_level_scan_info = candidate_scan_info
+        self.first_level_scan_gen_info = candidate_scan_gen_info
         assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
                 == candidate_scan_info.wg_size)
 
@@ -1173,7 +1318,7 @@ class GenericScanKernel(_GenericScanKernelBase):
         else:
             second_level_build_kwargs["is_segment_start_expr"] = None
 
-        self.second_level_scan_info = self.build_scan_kernel(
+        self.second_level_scan_gen_info = self.generate_scan_kernel(
                 max_scan_wg_size,
                 arguments=second_level_arguments,
                 input_expr="interval_sums[i]",
@@ -1186,26 +1331,21 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         # }}}
 
-        # {{{ build final update kernel
+        # {{{ generate final update kernel
 
-        self.update_wg_size = min(max_scan_wg_size, 256)
+        update_wg_size = min(max_scan_wg_size, 256)
 
         final_update_tpl = _make_template(UPDATE_SOURCE)
         final_update_src = str(final_update_tpl.render(
-            wg_size=self.update_wg_size,
+            wg_size=update_wg_size,
             output_statement=self.output_statement,
             argument_signature=", ".join(
                 arg.declarator() for arg in self.parsed_args),
             is_segment_start_expr=self.is_segment_start_expr,
             input_expr=_process_code_for_macro(self.input_expr),
-            use_lookbehind_update=use_lookbehind_update,
+            use_lookbehind_update=self.use_lookbehind_update,
             **self.code_variables))
 
-        final_update_prg = cl.Program(
-                self.context, final_update_src).build(self.options)
-        self.final_update_knl = getattr(
-                final_update_prg,
-                self.name_prefix+"_final_update")
         update_scalar_arg_dtypes = (
                 get_arg_list_scalar_arg_dtypes(self.parsed_args)
                 + [self.index_dtype, self.index_dtype, None, None])
@@ -1214,7 +1354,12 @@ class GenericScanKernel(_GenericScanKernelBase):
             update_scalar_arg_dtypes.append(None)
         if self.store_segment_start_flags:
             update_scalar_arg_dtypes.append(None)  # g_segment_start_flags
-        self.final_update_knl.set_scalar_arg_dtypes(update_scalar_arg_dtypes)
+
+        self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
+                final_update_src,
+                self.name_prefix + "_final_update",
+                update_scalar_arg_dtypes,
+                update_wg_size)
 
         # }}}
 
@@ -1248,7 +1393,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                     for arg_name, ife_offsets in list(fetch_expr_offsets.items())
                     if -1 in ife_offsets or len(ife_offsets) > 1))
 
-    def build_scan_kernel(self, max_wg_size, arguments, input_expr,
+    def generate_scan_kernel(self, max_wg_size, arguments, input_expr,
             is_segment_start_expr, input_fetch_exprs, is_first_level,
             store_segment_start_flags, k_group_size,
             use_bank_conflict_avoidance):
@@ -1278,22 +1423,21 @@ class GenericScanKernel(_GenericScanKernelBase):
             kernel_name=kernel_name,
             **self.code_variables))
 
-        prg = cl.Program(self.context, scan_src).build(self.options)
-
-        knl = getattr(prg, kernel_name)
-
         scalar_arg_dtypes.extend(
-                (None, self.index_dtype, self. index_dtype))
+                (None, self.index_dtype, self.index_dtype))
         if is_first_level:
             scalar_arg_dtypes.append(None)  # interval_results
         if self.is_segmented and is_first_level:
             scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
         if store_segment_start_flags:
             scalar_arg_dtypes.append(None)  # g_segment_start_flags
-        knl.set_scalar_arg_dtypes(scalar_arg_dtypes)
 
-        return _ScanKernelInfo(
-                kernel=knl, wg_size=wg_size, knl=knl, k_group_size=k_group_size)
+        return _GeneratedScanKernelInfo(
+                scan_src=scan_src,
+                kernel_name=kernel_name,
+                scalar_arg_dtypes=scalar_arg_dtypes,
+                wg_size=wg_size,
+                k_group_size=k_group_size)
 
     # }}}
 
@@ -1408,8 +1552,9 @@ class GenericScanKernel(_GenericScanKernelBase):
         if self.store_segment_start_flags:
             upd_args.append(segment_start_flags.data)
 
-        return self.final_update_knl(
-                queue, (num_intervals,), (self.update_wg_size,),
+        return self.final_update_info.kernel(
+                queue, (num_intervals,),
+                (self.final_update_info.update_wg_size,),
                 *upd_args, **dict(g_times_l=True, wait_for=[l2_evt]))
 
         # }}}
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index 5efdfdb0c6199b44aafa426f0b49f79aa4ceae39..d974980b8a17381e82994a84bd1860b93ef0ca98 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -36,6 +36,7 @@ from decorator import decorator
 import pyopencl as cl
 from pytools import memoize, memoize_method
 from pyopencl.cffi_cl import _lib
+from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
 
 import re
 
@@ -957,4 +958,30 @@ def is_spirv(s):
                 s[:4] == spirv_magic
                 or s[:4] == spirv_magic[::-1]))
 
+
+# {{{ numpy key types builder
+
+class _NumpyTypesKeyBuilder(KeyBuilderBase):
+    """Key builder that additionally accepts numpy scalar types as keys.
+
+    Type objects that are subclasses of :class:`numpy.generic` are keyed by
+    their ``__name__``; any other type object is rejected.
+    """
+
+    def update_for_type(self, key_hash, key):
+        """Feed the type object *key* into *key_hash*.
+
+        :raises TypeError: if *key* does not subclass ``np.generic``.
+        """
+        if issubclass(key, np.generic):
+            self.update_for_str(key_hash, key.__name__)
+            return
+
+        # *key* is itself the offending type; type(key) would only name its
+        # metaclass.
+        raise TypeError("unsupported type for persistent hash keying: %s"
+                % (key,))
+
+# }}}
+
 # vim: foldmethod=marker