diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 169939cab356e0f1ba7be00e7fa82e19e22979a5..a94adb8ad538539c408f3ecaed304652f562ffb4 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -4,6 +4,7 @@ from __future__ import division, absolute_import
 
 __copyright__ = """
 Copyright 2011-2012 Andreas Kloeckner
+Copyright 2017 Matt Wala
 Copyright 2008-2011 NVIDIA Corporation
 """
 
@@ -40,6 +41,7 @@ from pyopencl.tools import (dtype_to_ctype, bitlog2,
 import pyopencl._mymako as mako
 from pyopencl._cluda import CLUDA_PREAMBLE
 
+from pytools import Record, RecordWithoutPickling
 from pytools.persistent_dict import PersistentDict
 
 
@@ -70,6 +72,7 @@ typedef ${dtype_to_ctype(index_dtype)} index_type;
 
 # }}}
 
+
 # {{{ main scan code
 
 # Algorithm: Each work group is responsible for one contiguous
@@ -600,6 +603,7 @@ void ${kernel_name}(
 
 # }}}
 
+
 # {{{ update
 
 UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//
@@ -739,8 +743,6 @@ void ${name_prefix}_final_update(
 # }}}
 
 
-# {{{ driver
-
 # {{{ helpers
 
 def _round_down_to_power_of_2(val):
@@ -860,9 +862,10 @@ def _make_template(s):
 
     return mako.template.Template(s, strict_undefined=True)
 
+# }}}
 
-from pytools import Record, RecordWithoutPickling
 
+# {{{ data structures for code generation result
 
 class _GeneratedScanKernelInfo(Record):
 
@@ -930,13 +933,302 @@ class _BuiltFinalUpdateKernelInfo(RecordWithoutPickling):
 # }}}
 
 
+# {{{ generation and caching
+
+generic_scan_kernel_cache = PersistentDict(
+        "pyopencl-generated-scan-kernel-cache-v1",
+        key_builder=_NumpyTypesKeyBuilder())
+
+
+def _get_local_mem_use(k_group_size, wg_size, use_bank_conflict_avoidance):
+    arg_dtypes = {}
+    for arg in parsed_args:
+        arg_dtypes[arg.name] = arg.dtype
+
+    fetch_expr_offsets = {}
+    for name, arg_name, ife_offset in input_fetch_exprs:
+        fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+    itemsize = dtype.itemsize
+    if use_bank_conflict_avoidance:
+        itemsize += 4
+
+    return (
+            # ldata
+            itemsize*(k_group_size+1)*(wg_size+1)
+
+            # l_segment_start_flags
+            + k_group_size*wg_size
+
+            # l_first_segment_start_in_subtree
+            + index_dtype.itemsize*wg_size
+
+            + k_group_size*wg_size*sum(
+                arg_dtypes[arg_name].itemsize
+                for arg_name, ife_offsets in list(fetch_expr_offsets.items())
+                if -1 in ife_offsets or len(ife_offsets) > 1))
+
+
+def _generate_scan_kernel(max_wg_size, arguments, input_expr,
+        is_segment_start_expr, input_fetch_exprs, is_first_level,
+        store_segment_start_flags, k_group_size,
+        use_bank_conflict_avoidance, code_variables):
+    scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
+
+    # Empirically found on Nv hardware: no need to be bigger than this size
+    wg_size = _round_down_to_power_of_2(
+            min(max_wg_size, 256))
+
+    kernel_name = code_variables["name_prefix"]
+    if is_first_level:
+        kernel_name += "_lev1"
+    else:
+        kernel_name += "_lev2"
+
+    scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
+    scan_src = str(scan_tpl.render(
+        wg_size=wg_size,
+        input_expr=input_expr,
+        k_group_size=k_group_size,
+        argument_signature=", ".join(arg.declarator() for arg in arguments),
+        is_segment_start_expr=is_segment_start_expr,
+        input_fetch_exprs=input_fetch_exprs,
+        is_first_level=is_first_level,
+        store_segment_start_flags=store_segment_start_flags,
+        use_bank_conflict_avoidance=use_bank_conflict_avoidance,
+        kernel_name=kernel_name,
+        **code_variables))
+
+    scalar_arg_dtypes.extend(
+            (None, self.index_dtype, self.index_dtype))
+    if is_first_level:
+        scalar_arg_dtypes.append(None)  # interval_results
+    if self.is_segmented and is_first_level:
+        scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
+    if store_segment_start_flags:
+        scalar_arg_dtypes.append(None)  # g_segment_start_flags
+
+    return _GeneratedScanKernelInfo(
+            scan_src=scan_src,
+            kernel_name=kernel_name,
+            scalar_arg_dtypes=scalar_arg_dtypes,
+            wg_size=wg_size,
+            k_group_size=k_group_size)
+
+
+def _generate_scan_code(context, options, dtype, index_dtype, is_segmented):
+    devices = context.devices
+
+    # {{{ find usable workgroup/k-group size, build first-level scan
+
+    trip_count = 0
+
+    avail_local_mem = min(
+            dev.local_mem_size
+            for dev in devices)
+
+    if "CUDA" in devices[0].platform.name:
+        # not sure where these go, but roughly this much seems unavailable.
+        avail_local_mem -= 0x400
+
+    is_cpu = devices[0].type & cl.device_type.CPU
+    is_gpu = devices[0].type & cl.device_type.GPU
+
+    if is_cpu:
+        # (about the widest vector a CPU can support, also taking
+        # into account that CPUs don't hide latency by large work groups)
+        max_scan_wg_size = 16
+        wg_size_multiples = 4
+    else:
+        max_scan_wg_size = min(dev.max_work_group_size for dev in devices)
+        wg_size_multiples = 64
+
+    use_bank_conflict_avoidance = (
+            dtype.itemsize > 4 and dtype.itemsize % 8 == 0 and is_gpu)
+
+    # k_group_size should be a power of two because of in-kernel
+    # division by that number.
+
+    solutions = []
+    for k_exp in range(0, 9):
+        for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
+                wg_size_multiples):
+
+            k_group_size = 2**k_exp
+            lmem_use = _get_local_mem_use(k_group_size, wg_size,
+                    use_bank_conflict_avoidance)
+            if lmem_use <= avail_local_mem:
+                solutions.append((wg_size*k_group_size, k_group_size, wg_size))
+
+    if is_gpu:
+        from pytools import any
+        for wg_size_floor in [256, 192, 128]:
+            have_sol_above_floor = any(wg_size >= wg_size_floor
+                    for _, _, wg_size in solutions)
+
+            if have_sol_above_floor:
+                # delete all solutions not meeting the wg size floor
+                solutions = [(total, try_k_group_size, try_wg_size)
+                        for total, try_k_group_size, try_wg_size in solutions
+                        if try_wg_size >= wg_size_floor]
+                break
+
+    _, k_group_size, max_scan_wg_size = max(solutions)
+
+    while True:
+        candidate_scan_gen_info = _generate_scan_kernel(
+                max_scan_wg_size, self.parsed_args,
+                _process_code_for_macro(self.input_expr),
+                self.is_segment_start_expr,
+                input_fetch_exprs=self.input_fetch_exprs,
+                is_first_level=True,
+                store_segment_start_flags=self.store_segment_start_flags,
+                k_group_size=k_group_size,
+                use_bank_conflict_avoidance=use_bank_conflict_avoidance)
+
+        candidate_scan_info = candidate_scan_gen_info.build(context, options)
+
+        # Will this device actually let us execute this kernel
+        # at the desired work group size? Building it is the
+        # only way to find out.
+        kernel_max_wg_size = min(
+                candidate_scan_info.kernel.get_work_group_info(
+                    cl.kernel_work_group_info.WORK_GROUP_SIZE,
+                    dev)
+                for dev in devices)
+
+        if candidate_scan_info.wg_size <= kernel_max_wg_size:
+            break
+        else:
+            max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)
+
+        trip_count += 1
+        assert trip_count <= 20
+
+    first_level_scan_gen_info = candidate_scan_gen_info
+    assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
+            == candidate_scan_info.wg_size)
+
+    # }}}
+
+    # {{{ build second-level scan
+
+    from pyopencl.tools import VectorArg
+    second_level_arguments = self.parsed_args + [
+            VectorArg(dtype, "interval_sums")]
+
+    second_level_build_kwargs = {}
+    if is_segmented:
+        second_level_arguments.append(
+                VectorArg(index_dtype,
+                    "g_first_segment_start_in_interval_input"))
+
+        # is_segment_start_expr answers the question "should previous sums
+        # spill over into this item". And since
+        # g_first_segment_start_in_interval_input answers the question if a
+        # segment boundary was found in an interval of data, then if not,
+        # it's ok to spill over.
+        second_level_build_kwargs["is_segment_start_expr"] = \
+                "g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY"
+    else:
+        second_level_build_kwargs["is_segment_start_expr"] = None
+
+    second_level_scan_gen_info = _generate_scan_kernel(
+            max_scan_wg_size,
+            arguments=second_level_arguments,
+            input_expr="interval_sums[i]",
+            input_fetch_exprs=[],
+            is_first_level=False,
+            store_segment_start_flags=False,
+            k_group_size=k_group_size,
+            use_bank_conflict_avoidance=use_bank_conflict_avoidance,
+            **second_level_build_kwargs)
+
+    # }}}
+
+    # {{{ generate final update kernel
+
+    update_wg_size = min(max_scan_wg_size, 256)
+
+    final_update_tpl = _make_template(UPDATE_SOURCE)
+    final_update_src = str(final_update_tpl.render(
+        wg_size=update_wg_size,
+        output_statement=self.output_statement,
+        argument_signature=", ".join(
+            arg.declarator() for arg in self.parsed_args),
+        is_segment_start_expr=self.is_segment_start_expr,
+        input_expr=_process_code_for_macro(self.input_expr),
+        use_lookbehind_update=self.use_lookbehind_update,
+        **self.code_variables))
+
+    update_scalar_arg_dtypes = (
+            get_arg_list_scalar_arg_dtypes(self.parsed_args)
+            + [self.index_dtype, self.index_dtype, None, None])
+    if self.is_segmented:
+        # g_first_segment_start_in_interval
+        update_scalar_arg_dtypes.append(None)
+    if self.store_segment_start_flags:
+        update_scalar_arg_dtypes.append(None)  # g_segment_start_flags
+
+    self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
+            final_update_src,
+            self.name_prefix + "_final_update",
+            update_scalar_arg_dtypes,
+            update_wg_size)
+
+    # }}}
+
+
+def generate_scan_code_cached(context, options):
+    devices = context.devices
+    # Before generating the kernel, see if it's cached.
+    from pyopencl.cache import get_device_cache_id
+    devices_key = tuple(get_device_cache_id(device)
+            for device in devices)
+
+    cache_key = (kernel_key, devices_key)
+
+    from_cache = False
+
+    try:
+        result = generic_scan_kernel_cache[cache_key]
+        from_cache = True
+        logger.debug(
+                "cache hit for generated scan kernel '%s'" % self.name_prefix)
+        (
+            first_level_scan_gen_info,
+            second_level_scan_gen_info,
+            final_update_gen_info) = result
+    except KeyError:
+        pass
+
+    if not from_cache:
+        logger.debug(
+                "cache miss for generated scan kernel '%s'" % self.name_prefix)
+
+        (first_level_scan_gen_info,
+                  second_level_scan_gen_info,
+                  final_update_gen_info) = _generate_scan_code()
+
+        generic_scan_kernel_cache[cache_key] = (first_level_scan_gen_info, second_level_scan_gen_info, final_update_gen_info)
+
+    # Build the kernels.
+    return (
+            first_level_scan_gen_info.build(context, options),
+            second_level_scan_gen_info.build(context, options),
+            final_update_gen_info.build(context, options))
+
+
+# }}}
+
+
 class ScanPerformanceWarning(UserWarning):
     pass
 
 
-class _GenericScanKernelBase(object):
-    # {{{ constructor, argument processing
+# {{{ driver base class
 
+class _GenericScanKernelBase(object):
     def __init__(self, ctx, dtype,
             arguments, input_expr, scan_expr, neutral, output_statement,
             is_segment_start_expr=None, input_fetch_exprs=[],
@@ -1127,14 +1419,7 @@ class _GenericScanKernelBase(object):
         self.store_segment_start_flags = (
                 self.is_segmented and self.use_lookbehind_update)
 
-        self.finish_setup()
-
-    # }}}
-
-
-generic_scan_kernel_cache = PersistentDict(
-        "pyopencl-generated-scan-kernel-cache-v1",
-        key_builder=_NumpyTypesKeyBuilder())
+# }}}
 
 
 class GenericScanKernel(_GenericScanKernelBase):
@@ -1156,289 +1441,7 @@ class GenericScanKernel(_GenericScanKernelBase):
 
     """
 
-    def finish_setup(self):
-        # Before generating the kernel, see if it's cached.
-        from pyopencl.cache import get_device_cache_id
-        devices_key = tuple(get_device_cache_id(device)
-                for device in self.devices)
-
-        cache_key = (self.kernel_key, devices_key)
-
-        from_cache = False
-
-        try:
-            result = generic_scan_kernel_cache[cache_key]
-            from_cache = True
-            logger.debug(
-                    "cache hit for generated scan kernel '%s'" % self.name_prefix)
-            (
-                self.first_level_scan_gen_info,
-                self.second_level_scan_gen_info,
-                self.final_update_gen_info) = result
-        except KeyError:
-            pass
-
-        if not from_cache:
-            logger.debug(
-                    "cache miss for generated scan kernel '%s'" % self.name_prefix)
-            self._finish_setup_impl()
-
-            result = (self.first_level_scan_gen_info,
-                      self.second_level_scan_gen_info,
-                      self.final_update_gen_info)
-
-            generic_scan_kernel_cache[cache_key] = result
-
-        # Build the kernels.
-        self.first_level_scan_info = self.first_level_scan_gen_info.build(
-                self.context, self.options)
-        del self.first_level_scan_gen_info
-
-        self.second_level_scan_info = self.second_level_scan_gen_info.build(
-                self.context, self.options)
-        del self.second_level_scan_gen_info
-
-        self.final_update_info = self.final_update_gen_info.build(
-                self.context, self.options)
-        del self.final_update_gen_info
-
-    def _finish_setup_impl(self):
-        # {{{ find usable workgroup/k-group size, build first-level scan
-
-        trip_count = 0
-
-        avail_local_mem = min(
-                dev.local_mem_size
-                for dev in self.devices)
-
-        if "CUDA" in self.devices[0].platform.name:
-            # not sure where these go, but roughly this much seems unavailable.
-            avail_local_mem -= 0x400
-
-        is_cpu = self.devices[0].type & cl.device_type.CPU
-        is_gpu = self.devices[0].type & cl.device_type.GPU
-
-        if is_cpu:
-            # (about the widest vector a CPU can support, also taking
-            # into account that CPUs don't hide latency by large work groups
-            max_scan_wg_size = 16
-            wg_size_multiples = 4
-        else:
-            max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
-            wg_size_multiples = 64
-
-        use_bank_conflict_avoidance = (
-                self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0 and is_gpu)
-
-        # k_group_size should be a power of two because of in-kernel
-        # division by that number.
-
-        solutions = []
-        for k_exp in range(0, 9):
-            for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
-                    wg_size_multiples):
-
-                k_group_size = 2**k_exp
-                lmem_use = self.get_local_mem_use(wg_size, k_group_size,
-                        use_bank_conflict_avoidance)
-                if lmem_use <= avail_local_mem:
-                    solutions.append((wg_size*k_group_size, k_group_size, wg_size))
-
-        if is_gpu:
-            from pytools import any
-            for wg_size_floor in [256, 192, 128]:
-                have_sol_above_floor = any(wg_size >= wg_size_floor
-                        for _, _, wg_size in solutions)
-
-                if have_sol_above_floor:
-                    # delete all solutions not meeting the wg size floor
-                    solutions = [(total, try_k_group_size, try_wg_size)
-                            for total, try_k_group_size, try_wg_size in solutions
-                            if try_wg_size >= wg_size_floor]
-                    break
-
-        _, k_group_size, max_scan_wg_size = max(solutions)
-
-        while True:
-            candidate_scan_gen_info = self.generate_scan_kernel(
-                    max_scan_wg_size, self.parsed_args,
-                    _process_code_for_macro(self.input_expr),
-                    self.is_segment_start_expr,
-                    input_fetch_exprs=self.input_fetch_exprs,
-                    is_first_level=True,
-                    store_segment_start_flags=self.store_segment_start_flags,
-                    k_group_size=k_group_size,
-                    use_bank_conflict_avoidance=use_bank_conflict_avoidance)
-
-            candidate_scan_info = candidate_scan_gen_info.build(
-                    self.context, self.options)
-
-            # Will this device actually let us execute this kernel
-            # at the desired work group size? Building it is the
-            # only way to find out.
-            kernel_max_wg_size = min(
-                    candidate_scan_info.kernel.get_work_group_info(
-                        cl.kernel_work_group_info.WORK_GROUP_SIZE,
-                        dev)
-                    for dev in self.devices)
-
-            if candidate_scan_info.wg_size <= kernel_max_wg_size:
-                break
-            else:
-                max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)
-
-            trip_count += 1
-            assert trip_count <= 20
-
-        self.first_level_scan_gen_info = candidate_scan_gen_info
-        assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
-                == candidate_scan_info.wg_size)
-
-        # }}}
-
-        # {{{ build second-level scan
-
-        from pyopencl.tools import VectorArg
-        second_level_arguments = self.parsed_args + [
-                VectorArg(self.dtype, "interval_sums")]
-
-        second_level_build_kwargs = {}
-        if self.is_segmented:
-            second_level_arguments.append(
-                    VectorArg(self.index_dtype,
-                        "g_first_segment_start_in_interval_input"))
-
-            # is_segment_start_expr answers the question "should previous sums
-            # spill over into this item". And since
-            # g_first_segment_start_in_interval_input answers the question if a
-            # segment boundary was found in an interval of data, then if not,
-            # it's ok to spill over.
-            second_level_build_kwargs["is_segment_start_expr"] = \
-                    "g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY"
-        else:
-            second_level_build_kwargs["is_segment_start_expr"] = None
-
-        self.second_level_scan_gen_info = self.generate_scan_kernel(
-                max_scan_wg_size,
-                arguments=second_level_arguments,
-                input_expr="interval_sums[i]",
-                input_fetch_exprs=[],
-                is_first_level=False,
-                store_segment_start_flags=False,
-                k_group_size=k_group_size,
-                use_bank_conflict_avoidance=use_bank_conflict_avoidance,
-                **second_level_build_kwargs)
-
-        # }}}
-
-        # {{{ generate final update kernel
-
-        update_wg_size = min(max_scan_wg_size, 256)
-
-        final_update_tpl = _make_template(UPDATE_SOURCE)
-        final_update_src = str(final_update_tpl.render(
-            wg_size=update_wg_size,
-            output_statement=self.output_statement,
-            argument_signature=", ".join(
-                arg.declarator() for arg in self.parsed_args),
-            is_segment_start_expr=self.is_segment_start_expr,
-            input_expr=_process_code_for_macro(self.input_expr),
-            use_lookbehind_update=self.use_lookbehind_update,
-            **self.code_variables))
-
-        update_scalar_arg_dtypes = (
-                get_arg_list_scalar_arg_dtypes(self.parsed_args)
-                + [self.index_dtype, self.index_dtype, None, None])
-        if self.is_segmented:
-            # g_first_segment_start_in_interval
-            update_scalar_arg_dtypes.append(None)
-        if self.store_segment_start_flags:
-            update_scalar_arg_dtypes.append(None)  # g_segment_start_flags
-
-        self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
-                final_update_src,
-                self.name_prefix + "_final_update",
-                update_scalar_arg_dtypes,
-                update_wg_size)
-
-        # }}}
-
     # {{{ scan kernel build/properties
-
-    def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance):
-        arg_dtypes = {}
-        for arg in self.parsed_args:
-            arg_dtypes[arg.name] = arg.dtype
-
-        fetch_expr_offsets = {}
-        for name, arg_name, ife_offset in self.input_fetch_exprs:
-            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
-
-        itemsize = self.dtype.itemsize
-        if use_bank_conflict_avoidance:
-            itemsize += 4
-
-        return (
-                # ldata
-                itemsize*(k_group_size+1)*(wg_size+1)
-
-                # l_segment_start_flags
-                + k_group_size*wg_size
-
-                # l_first_segment_start_in_subtree
-                + self.index_dtype.itemsize*wg_size
-
-                + k_group_size*wg_size*sum(
-                    arg_dtypes[arg_name].itemsize
-                    for arg_name, ife_offsets in list(fetch_expr_offsets.items())
-                    if -1 in ife_offsets or len(ife_offsets) > 1))
-
-    def generate_scan_kernel(self, max_wg_size, arguments, input_expr,
-            is_segment_start_expr, input_fetch_exprs, is_first_level,
-            store_segment_start_flags, k_group_size,
-            use_bank_conflict_avoidance):
-        scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
-
-        # Empirically found on Nv hardware: no need to be bigger than this size
-        wg_size = _round_down_to_power_of_2(
-                min(max_wg_size, 256))
-
-        kernel_name = self.code_variables["name_prefix"]
-        if is_first_level:
-            kernel_name += "_lev1"
-        else:
-            kernel_name += "_lev2"
-
-        scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
-        scan_src = str(scan_tpl.render(
-            wg_size=wg_size,
-            input_expr=input_expr,
-            k_group_size=k_group_size,
-            argument_signature=", ".join(arg.declarator() for arg in arguments),
-            is_segment_start_expr=is_segment_start_expr,
-            input_fetch_exprs=input_fetch_exprs,
-            is_first_level=is_first_level,
-            store_segment_start_flags=store_segment_start_flags,
-            use_bank_conflict_avoidance=use_bank_conflict_avoidance,
-            kernel_name=kernel_name,
-            **self.code_variables))
-
-        scalar_arg_dtypes.extend(
-                (None, self.index_dtype, self.index_dtype))
-        if is_first_level:
-            scalar_arg_dtypes.append(None)  # interval_results
-        if self.is_segmented and is_first_level:
-            scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
-        if store_segment_start_flags:
-            scalar_arg_dtypes.append(None)  # g_segment_start_flags
-
-        return _GeneratedScanKernelInfo(
-                scan_src=scan_src,
-                kernel_name=kernel_name,
-                scalar_arg_dtypes=scalar_arg_dtypes,
-                wg_size=wg_size,
-                k_group_size=k_group_size)
-
     # }}}
 
     def __call__(self, *args, **kwargs):
@@ -1623,6 +1626,10 @@ void ${name_prefix}_debug_scan(
 }
 """
 
+# }}}
+
+
+# {{{ debug driver
 
 class GenericDebugScanKernel(_GenericScanKernelBase):
     def finish_setup(self):