diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index fa1012d8b694dbf6096aa817b201c043a7fb5a9c..1701e16843f776200c54c0cd6c0f1ebc1d5ba56b 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -1,3 +1,7 @@
+# WARNING!
+# If you update this file, make sure to also update the sister copy in
+# PyCUDA or PyOpenCL--both files should always be exactly identical.
+
 """Scan primitive."""
 
 from __future__ import division
@@ -27,36 +31,23 @@ within the Thrust project, https://code.google.com/p/thrust/
 
 
 
-import pyopencl as cl
-import pyopencl.array as cl_array
-from pyopencl.tools import dtype_to_ctype
 import numpy as np
-import pyopencl._mymako as mako
-
-
-
 
-CLUDA_PREAMBLE = """
-#define local_barrier() barrier(CLK_LOCAL_MEM_FENCE);
+_CL_MODE = "pyopencl" in __name__
 
-#define WITHIN_KERNEL /* empty */
-#define KERNEL __kernel
-#define GLOBAL_MEM __global
-#define LOCAL_MEM __local
-#define REQD_WG_SIZE(X,Y,Z) __attribute__((reqd_work_group_size(X, Y, Z)))
-
-#define LID_0 get_local_id(0)
-#define LID_1 get_local_id(1)
-#define LID_2 get_local_id(2)
-
-#define GID_0 get_group_id(0)
-#define GID_1 get_group_id(1)
-#define GID_2 get_group_id(2)
-
-% if double_support:
-    #pragma OPENCL EXTENSION cl_khr_fp64: enable
-% endif
-"""
+if _CL_MODE:
+    import pyopencl as cl
+    import pyopencl.array as cl_array
+    from pyopencl.tools import dtype_to_ctype
+    import pyopencl._mymako as mako
+    from pyopencl._cluda import CLUDA_PREAMBLE
+else:
+    import pycuda.driver as driver
+    import pycuda.gpuarray as gpuarray
+    from pycuda.compiler import SourceModule
+    from pycuda.tools import dtype_to_ctype
+    import pycuda._mymako as mako
+    from pycuda._cluda import CLUDA_PREAMBLE
 
 
 
@@ -78,7 +69,7 @@ SCAN_INTERVALS_SOURCE = mako.template.Template(SHARED_PREAMBLE + """
 
 <%def name="make_group_scan(name, with_bounds_check)">
     WITHIN_KERNEL
-    void ${name}(LOCAL_MEM scan_type *array
+    void ${name}(LOCAL_MEM_ARG scan_type *array
     % if with_bounds_check:
       , const unsigned n
     % endif
@@ -398,119 +389,218 @@ def _uniform_interval_splitting(n, granularity, max_intervals):
 
 
 
-class _ScanKernelBase(object):
-    def __init__(self, ctx, dtype,
-            scan_expr, neutral=None,
-            name_prefix="scan", options="", preamble="", devices=None):
+if _CL_MODE:
+    class _ScanKernelBase(object):
+        def __init__(self, ctx, dtype,
+                scan_expr, neutral=None,
+                name_prefix="scan", options="", preamble="", devices=None):
+
+            self.context = ctx
+            dtype = self.dtype = np.dtype(dtype)
+            self.neutral = neutral
+
+            if devices is None:
+                devices = ctx.devices
+            self.devices = devices
+
+            max_wg_size = min(dev.max_work_group_size for dev in self.devices)
+
+            # Thrust says these are good for GT200
+            self.scan_wg_size = min(max_wg_size, 128)
+            self.update_wg_size = min(max_wg_size, 256)
+
+            if self.scan_wg_size < 16:
+                # Hello, Apple CPU. Nice to see you.
+                self.scan_wg_seq_batches = 128 # FIXME: guesswork
+            else:
+                self.scan_wg_seq_batches = 6
+
+            from pytools import all
+            from pyopencl.characterize import has_double_support
+
+            kw_values = dict(
+                preamble=preamble,
+                name_prefix=name_prefix,
+                scan_type=dtype_to_ctype(dtype),
+                scan_expr=scan_expr,
+                neutral=neutral,
+                double_support=all(
+                    has_double_support(dev) for dev in devices)
+                )
+
+            scan_intervals_src = str(SCAN_INTERVALS_SOURCE.render(
+                wg_size=self.scan_wg_size,
+                wg_seq_batches=self.scan_wg_seq_batches,
+                **kw_values))
+            scan_intervals_prg = cl.Program(ctx, scan_intervals_src).build(options)
+            self.scan_intervals_knl = getattr(
+                    scan_intervals_prg,
+                    name_prefix+"_scan_intervals")
+            self.scan_intervals_knl.set_scalar_arg_dtypes(
+                    (None, np.uint32, np.uint32, None, None))
+
+            final_update_src = str(self.final_update_tp.render(
+                wg_size=self.update_wg_size,
+                **kw_values))
+
+            final_update_prg = cl.Program(self.context, final_update_src).build(options)
+            self.final_update_knl = getattr(
+                    final_update_prg,
+                    name_prefix+"_final_update")
+            self.final_update_knl.set_scalar_arg_dtypes(
+                    (None, np.uint32, np.uint32, None))
+
+        def __call__(self, input_ary, output_ary=None, allocator=None,
+                queue=None):
+            allocator = allocator or input_ary.allocator
+            queue = queue or input_ary.queue or output_ary.queue
+
+            if output_ary is None:
+                output_ary = input_ary
+
+            if isinstance(output_ary, (str, unicode)) and output_ary == "new":
+                output_ary = cl_array.empty_like(input_ary, allocator=allocator)
+
+            if input_ary.shape != output_ary.shape:
+                raise ValueError("input and output must have the same shape")
+
+            n, = input_ary.shape
+
+            if not n:
+                return output_ary
+
+            unit_size  = self.scan_wg_size * self.scan_wg_seq_batches
+            max_groups = 3*max(dev.max_compute_units for dev in self.devices)
+
+            interval_size, num_groups = _uniform_interval_splitting(
+                    n, unit_size, max_groups);
+
+            block_results = allocator(self.dtype.itemsize*num_groups)
+            dummy_results = allocator(self.dtype.itemsize)
+
+            # first level scan of interval (one interval per block)
+            self.scan_intervals_knl(
+                    queue, (num_groups*self.scan_wg_size,), (self.scan_wg_size,),
+                    input_ary.data,
+                    n, interval_size,
+                    output_ary.data,
+                    block_results)
+
+            # second level inclusive scan of per-block results
+            self.scan_intervals_knl(
+                    queue, (self.scan_wg_size,), (self.scan_wg_size,),
+                    block_results,
+                    num_groups, interval_size,
+                    block_results,
+                    dummy_results)
+
+            # update intervals with result of second level scan
+            self.final_update_knl(
+                    queue, (num_groups*self.update_wg_size,), (self.update_wg_size,),
+                    output_ary.data,
+                    n, interval_size,
+                    block_results)
+
+            return output_ary
+
 
-        self.context = ctx
-        dtype = self.dtype = np.dtype(dtype)
-        self.neutral = neutral
 
-        if devices is None:
-            devices = ctx.devices
-        self.devices = devices
 
-        max_wg_size = min(dev.max_work_group_size for dev in self.devices)
+else:
+    class _ScanKernelBase(object):
+        def __init__(self, dtype,
+                scan_expr, neutral=None,
+                name_prefix="scan", options=[], preamble="", devices=None):
 
-        # Thrust says these are good for GT200
-        self.scan_wg_size = min(max_wg_size, 128)
-        self.update_wg_size = min(max_wg_size, 256)
+            dtype = self.dtype = np.dtype(dtype)
+            self.neutral = neutral
 
-        if self.scan_wg_size < 16:
-            # Hello, Apple CPU. Nice to see you.
-            self.scan_wg_seq_batches = 128 # FIXME: guesswork
-        else:
+            # Thrust says these are good for GT200
+            self.scan_wg_size = 128
+            self.update_wg_size = 256
             self.scan_wg_seq_batches = 6
 
-        from pytools import all
-        from pyopencl.characterize import has_double_support
-
-        kw_values = dict(
-            preamble=preamble,
-            name_prefix=name_prefix,
-            scan_type=dtype_to_ctype(dtype),
-            scan_expr=scan_expr,
-            neutral=neutral,
-            double_support=all(
-                has_double_support(dev) for dev in devices)
-            )
+            kw_values = dict(
+                preamble=preamble,
+                name_prefix=name_prefix,
+                scan_type=dtype_to_ctype(dtype),
+                scan_expr=scan_expr,
+                neutral=neutral)
+
+            scan_intervals_src = str(SCAN_INTERVALS_SOURCE.render(
+                wg_size=self.scan_wg_size,
+                wg_seq_batches=self.scan_wg_seq_batches,
+                **kw_values))
+            scan_intervals_prg = SourceModule(
+                    scan_intervals_src, options=options, no_extern_c=True)
+            self.scan_intervals_knl = scan_intervals_prg.get_function(
+                    name_prefix+"_scan_intervals")
+            self.scan_intervals_knl.prepare("PIIPP", (self.scan_wg_size, 1, 1))
+
+            final_update_src = str(self.final_update_tp.render(
+                wg_size=self.update_wg_size,
+                **kw_values))
+
+            final_update_prg = SourceModule(
+                    final_update_src, options=options, no_extern_c=True)
+            self.final_update_knl = final_update_prg.get_function(
+                    name_prefix+"_final_update")
+            self.final_update_knl.prepare("PIIP", (self.update_wg_size, 1, 1))
+
+        def __call__(self, input_ary, output_ary=None, allocator=None,
+                stream=None):
+            allocator = allocator or input_ary.allocator
+
+            if output_ary is None:
+                output_ary = input_ary
+
+            if isinstance(output_ary, (str, unicode)) and output_ary == "new":
+                output_ary = cl_array.empty_like(input_ary, allocator=allocator)
+
+            if input_ary.shape != output_ary.shape:
+                raise ValueError("input and output must have the same shape")
+
+            n, = input_ary.shape
+
+            if not n:
+                return output_ary
+
+            unit_size  = self.scan_wg_size * self.scan_wg_seq_batches
+            dev = driver.Context.get_device()
+            max_groups = 3*dev.get_attribute(
+                    driver.device_attribute.MULTIPROCESSOR_COUNT)
+
+            interval_size, num_groups = _uniform_interval_splitting(
+                    n, unit_size, max_groups);
+
+            block_results = allocator(self.dtype.itemsize*num_groups)
+            dummy_results = allocator(self.dtype.itemsize)
+
+            # first level scan of interval (one interval per block)
+            self.scan_intervals_knl.prepared_async_call(
+                    (num_groups, 1), stream,
+                    input_ary.gpudata,
+                    n, interval_size,
+                    output_ary.gpudata,
+                    block_results)
+
+            # second level inclusive scan of per-block results
+            self.scan_intervals_knl.prepared_async_call(
+                    (1, 1), stream,
+                    block_results,
+                    num_groups, interval_size,
+                    block_results,
+                    dummy_results)
+
+            # update intervals with result of second level scan
+            self.final_update_knl.prepared_async_call(
+                    (num_groups, 1,), stream,
+                    output_ary.gpudata,
+                    n, interval_size,
+                    block_results)
 
-        scan_intervals_src = str(SCAN_INTERVALS_SOURCE.render(
-            wg_size=self.scan_wg_size,
-            wg_seq_batches=self.scan_wg_seq_batches,
-            **kw_values))
-        scan_intervals_prg = cl.Program(ctx, scan_intervals_src).build(options)
-        self.scan_intervals_knl = getattr(
-                scan_intervals_prg,
-                name_prefix+"_scan_intervals")
-        self.scan_intervals_knl.set_scalar_arg_dtypes(
-                (None, np.uint32, np.uint32, None, None))
-
-        final_update_src = str(self.final_update_tp.render(
-            wg_size=self.update_wg_size,
-            **kw_values))
-
-        final_update_prg = cl.Program(self.context, final_update_src).build(options)
-        self.final_update_knl = getattr(
-                final_update_prg,
-                name_prefix+"_final_update")
-        self.final_update_knl.set_scalar_arg_dtypes(
-                (None, np.uint32, np.uint32, None))
-
-    def __call__(self, input_ary, output_ary=None, allocator=None,
-            queue=None):
-        allocator = allocator or input_ary.allocator
-        queue = queue or input_ary.queue or output_ary.queue
-
-        if output_ary is None:
-            output_ary = input_ary
-
-        if isinstance(output_ary, (str, unicode)) and output_ary == "new":
-            output_ary = cl_array.empty_like(input_ary, allocator=allocator)
-
-        if input_ary.shape != output_ary.shape:
-            raise ValueError("input and output must have the same shape")
-
-        n, = input_ary.shape
-
-        if not n:
             return output_ary
 
-        unit_size  = self.scan_wg_size * self.scan_wg_seq_batches
-        max_groups = 3*max(dev.max_compute_units for dev in self.devices)
-
-        interval_size, num_groups = _uniform_interval_splitting(
-                n, unit_size, max_groups);
-
-        block_results = allocator(self.dtype.itemsize*num_groups)
-        dummy_results = allocator(self.dtype.itemsize)
-
-        # first level scan of interval (one interval per block)
-        self.scan_intervals_knl(
-                queue, (num_groups*self.scan_wg_size,), (self.scan_wg_size,),
-                input_ary.data,
-                n, interval_size,
-                output_ary.data,
-                block_results)
-
-        # second level inclusive scan of per-block results
-        self.scan_intervals_knl(
-                queue, (self.scan_wg_size,), (self.scan_wg_size,),
-                block_results,
-                num_groups, interval_size,
-                block_results,
-                dummy_results)
-
-        # update intervals with result of second level scan
-        self.final_update_knl(
-                queue, (num_groups*self.update_wg_size,), (self.update_wg_size,),
-                output_ary.data,
-                n, interval_size,
-                block_results)
-
-        return output_ary
-
 
 
 class InclusiveScanKernel(_ScanKernelBase):