diff --git a/doc/source/array.rst b/doc/source/array.rst
index 8954f7a812da24257887d68efbef85a68fa3a20a..e2003ff8b3e2446ce92ee0ec92aa4ed14cdcbccf 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -604,6 +604,19 @@ Here's a usage example::
     assert (dev_data.get() == np.cumsum(host_data, axis=0)).all()
 
 
+Custom data types in Reduction and Scan
+---------------------------------------
+
+If you would like to use your own (struct/union/whatever) data types in
+scan and reduction, define those types in the *preamble* and let PyOpenCL
+know about them using this function:
+
+.. function:: pyopencl.tools.register_dtype(dtype, name)
+
+    *dtype* is a :class:`numpy.dtype`. *name* is the C name of that type, as declared in the *preamble*.
+
+    .. versionadded:: 2011.2
+
 Fast Fourier Transforms
 -----------------------
 
diff --git a/doc/source/misc.rst b/doc/source/misc.rst
index 71143f9b3ea7737b54c3ecc700b602f624b888ad..8e71234c2fbcadaba3ba215feb9efb8eff88def5 100644
--- a/doc/source/misc.rst
+++ b/doc/source/misc.rst
@@ -90,6 +90,7 @@ Version 2011.2
 * Add :class:`pyopencl.NannyEvent` objects.
 * Add :mod:`pyopencl.characterize`.
 * Ensure compatibility with OS X Lion.
+* Add :func:`pyopencl.tools.register_dtype` to enable scan/reduction on struct types.
 
 .. * Beta support for OpenCL 1.2.
 
diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index 0a5aaf422217b070cf268df5a12f77b10e2e0d23..f8cc9876f7ce3abbc0f0fc2a56424fbac6749236 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -225,7 +225,11 @@ def get_reduction_kernel(stage,
         map_expr = "in[i]"
 
     if stage == 2:
-        arguments = "__global const %s *in" % out_type
+        in_arg = "__global const %s *in" % out_type
+        if arguments:
+            arguments = in_arg + ", " + arguments
+        else:
+            arguments = in_arg
 
     inf = get_reduction_source(
             ctx, out_type, out_type_size,
@@ -311,6 +315,8 @@ class ReductionKernel:
         if kwargs:
             raise TypeError("invalid keyword argument to reduction kernel")
 
+        stage1_args = args
+
         while True:
             invocation_args = []
             vectors = []
@@ -363,7 +369,7 @@ class ReductionKernel:
                 return result
             else:
                 stage_inf = self.stage_2_inf
-                args = [result]
+                args = (result,) + stage1_args