From c671e9388f5cc6ecb6428021911ea9f4936fdbb8 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 Apr 2011 21:34:18 -0400
Subject: [PATCH] Pull in pycuda scan changes.

---
 pyopencl/scan.py | 41 +++++++++--------------------------------
 1 file changed, 9 insertions(+), 32 deletions(-)

diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 02679a31..ee6f74fe 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -364,31 +364,6 @@ void ${name_prefix}_final_update(
 
 
 
-def _div_ceil(nr, dr):
-    return (nr + dr -1) // dr
-
-
-def _uniform_interval_splitting(n, granularity, max_intervals):
-    grains  = _div_ceil(n, granularity)
-
-    # one grain per interval
-    if grains <= max_intervals:
-        return granularity, grains
-
-    # ensures that:
-    #     num_intervals * interval_size is >= n
-    #   and
-    #     (num_intervals - 1) * interval_size is < n
-
-    grains_per_interval = _div_ceil(grains, max_intervals)
-    interval_size = grains_per_interval * granularity
-    num_intervals = _div_ceil(n, interval_size)
-
-    return interval_size, num_intervals
-
-
-
-
 if _CL_MODE:
     class _ScanKernelBase(object):
         def __init__(self, ctx, dtype,
@@ -475,7 +450,8 @@ if _CL_MODE:
             unit_size  = self.scan_wg_size * self.scan_wg_seq_batches
             max_groups = 3*max(dev.max_compute_units for dev in self.devices)
 
-            interval_size, num_groups = _uniform_interval_splitting(
+            from pytools import uniform_interval_splitting
+            interval_size, num_groups = uniform_interval_splitting(
                     n, unit_size, max_groups);
 
             block_results = allocator(self.dtype.itemsize*num_groups)
@@ -541,7 +517,7 @@ else:
                     scan_intervals_src, options=options, no_extern_c=True)
             self.scan_intervals_knl = scan_intervals_prg.get_function(
                     name_prefix+"_scan_intervals")
-            self.scan_intervals_knl.prepare("PIIPP", (self.scan_wg_size, 1, 1))
+            self.scan_intervals_knl.prepare("PIIPP")
 
             final_update_src = str(self.final_update_tp.render(
                 wg_size=self.update_wg_size,
@@ -551,7 +527,7 @@ else:
                     final_update_src, options=options, no_extern_c=True)
             self.final_update_knl = final_update_prg.get_function(
                     name_prefix+"_final_update")
-            self.final_update_knl.prepare("PIIP", (self.update_wg_size, 1, 1))
+            self.final_update_knl.prepare("PIIP")
 
         def __call__(self, input_ary, output_ary=None, allocator=None,
                 stream=None):
@@ -576,7 +552,8 @@ else:
             max_groups = 3*dev.get_attribute(
                     driver.device_attribute.MULTIPROCESSOR_COUNT)
 
-            interval_size, num_groups = _uniform_interval_splitting(
+            from pytools import uniform_interval_splitting
+            interval_size, num_groups = uniform_interval_splitting(
                     n, unit_size, max_groups);
 
             block_results = allocator(self.dtype.itemsize*num_groups)
@@ -584,7 +561,7 @@ else:
 
             # first level scan of interval (one interval per block)
             self.scan_intervals_knl.prepared_async_call(
-                    (num_groups, 1), stream,
+                    (num_groups, 1), (self.scan_wg_size, 1, 1), stream,
                     input_ary.gpudata,
                     n, interval_size,
                     output_ary.gpudata,
@@ -592,7 +569,7 @@ else:
 
             # second level inclusive scan of per-block results
             self.scan_intervals_knl.prepared_async_call(
-                    (1, 1), stream,
+                    (1,1), (self.scan_wg_size, 1, 1), stream,
                     block_results,
                     num_groups, interval_size,
                     block_results,
@@ -600,7 +577,7 @@ else:
 
             # update intervals with result of second level scan
             self.final_update_knl.prepared_async_call(
-                    (num_groups, 1,), stream,
+                    (num_groups, 1,), (self.update_wg_size, 1, 1), stream,
                     output_ary.gpudata,
                     n, interval_size,
                     block_results)
-- 
GitLab