From 5d983b09f87b27e6474f21f6d04d38f11550149c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 24 Jul 2013 20:18:20 -0500
Subject: [PATCH] Fix lmem usage estimate in scan, parametrize scan test

---
 pyopencl/scan.py       | 31 +++++++++++++++++++++-------
 test/test_algorithm.py | 47 +++++++++++++++++++-----------------------
 2 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index ae380ff2..41dd7912 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -143,7 +143,7 @@ void ${name_prefix}_scan_intervals(
     )
 {
     // index K in first dimension used for carry storage
-    %if scan_dtype.itemsize > 4 and scan_dtype.itemsize % 8 == 0 and is_gpu:
+    %if use_bank_conflict_avoidance:
         // Avoid bank conflicts by adding a single 32-bit value to the size of
         // the scan type.
         struct __attribute__ ((__packed__)) wrapped_scan_type
@@ -1064,7 +1064,10 @@ class GenericScanKernel(_GenericScanKernelBase):
                 dev.local_mem_size
                 for dev in self.devices)
 
-        if self.devices[0].type == cl.device_type.CPU:
+        is_cpu = self.devices[0].type & cl.device_type.CPU
+        is_gpu = self.devices[0].type & cl.device_type.GPU
+
+        if is_cpu:
             # (about the widest vector a CPU can support, also taking
             # into account that CPUs don't hide latency by large work groups
             max_scan_wg_size = 16
@@ -1073,6 +1076,9 @@ class GenericScanKernel(_GenericScanKernelBase):
             max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
             wg_size_multiples = 64
 
+        use_bank_conflict_avoidance = (
+                self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0 and is_gpu)
+
         # k_group_size should be a power of two because of in-kernel
         # division by that number.
 
@@ -1082,11 +1088,12 @@ class GenericScanKernel(_GenericScanKernelBase):
                     wg_size_multiples):
 
                 k_group_size = 2**k_exp
-                lmem_use = self.get_local_mem_use(wg_size, k_group_size)
+                lmem_use = self.get_local_mem_use(k_group_size, wg_size,
+                        use_bank_conflict_avoidance)
                 if lmem_use + 256 <= avail_local_mem:
                     solutions.append((wg_size*k_group_size, k_group_size, wg_size))
 
-        if self.devices[0].type & cl.device_type.GPU:
+        if is_gpu:
             from pytools import any
             for wg_size_floor in [256, 192, 128]:
                 have_sol_above_floor = any(wg_size >= wg_size_floor
@@ -1109,7 +1116,8 @@ class GenericScanKernel(_GenericScanKernelBase):
                     input_fetch_exprs=self.input_fetch_exprs,
                     is_first_level=True,
                     store_segment_start_flags=self.store_segment_start_flags,
-                    k_group_size=k_group_size)
+                    k_group_size=k_group_size,
+                    use_bank_conflict_avoidance=use_bank_conflict_avoidance)
 
             # Will this device actually let us execute this kernel
             # at the desired work group size? Building it is the
@@ -1164,6 +1172,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                 is_first_level=False,
                 store_segment_start_flags=False,
                 k_group_size=k_group_size,
+                use_bank_conflict_avoidance=use_bank_conflict_avoidance,
                 **second_level_build_kwargs)
 
         # }}}
@@ -1202,7 +1211,7 @@ class GenericScanKernel(_GenericScanKernelBase):
 
     # {{{ scan kernel build/properties
 
-    def get_local_mem_use(self, k_group_size, wg_size):
+    def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance):
         arg_dtypes = {}
         for arg in self.parsed_args:
             arg_dtypes[arg.name] = arg.dtype
@@ -1211,9 +1220,13 @@ class GenericScanKernel(_GenericScanKernelBase):
         for name, arg_name, ife_offset in self.input_fetch_exprs:
             fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
 
+        itemsize = self.dtype.itemsize
+        if use_bank_conflict_avoidance:
+            itemsize += 4
+
         return (
                 # ldata
-                self.dtype.itemsize*(k_group_size+1)*(wg_size+1)
+                itemsize*(k_group_size+1)*(wg_size+1)
 
                 # l_segment_start_flags
                 + k_group_size*wg_size
@@ -1228,7 +1241,8 @@ class GenericScanKernel(_GenericScanKernelBase):
 
     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
             is_segment_start_expr, input_fetch_exprs, is_first_level,
-            store_segment_start_flags, k_group_size):
+            store_segment_start_flags, k_group_size,
+            use_bank_conflict_avoidance):
         scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
 
         # Empirically found on Nv hardware: no need to be bigger than this size
@@ -1245,6 +1259,7 @@ class GenericScanKernel(_GenericScanKernelBase):
             input_fetch_exprs=input_fetch_exprs,
             is_first_level=is_first_level,
             store_segment_start_flags=store_segment_start_flags,
+            use_bank_conflict_avoidance=use_bank_conflict_avoidance,
             **self.code_variables))
 
         prg = cl.Program(self.context, scan_src).build(self.options)
diff --git a/test/test_algorithm.py b/test/test_algorithm.py
index bcc687e2..7f0f9f4c 100644
--- a/test/test_algorithm.py
+++ b/test/test_algorithm.py
@@ -486,44 +486,39 @@ scan_test_counts = [
     ]
 
 
-def test_scan(ctx_factory):
+@pytest.mark.parametrize("dtype", [np.int32, np.int64])
+@pytest.mark.parametrize("scan_cls", [InclusiveScanKernel, ExclusiveScanKernel])
+def test_scan(ctx_factory, dtype, scan_cls):
     from pytest import importorskip
     importorskip("mako")
 
     context = ctx_factory()
     queue = cl.CommandQueue(context)
 
-    from pyopencl.scan import InclusiveScanKernel, ExclusiveScanKernel
+    knl = scan_cls(context, dtype, "a+b", "0")
 
-    dtype = np.int32
-    for cls in [
-            InclusiveScanKernel,
-            ExclusiveScanKernel
-            ]:
-        knl = cls(context, dtype, "a+b", "0")
-
-        for n in scan_test_counts:
-            host_data = np.random.randint(0, 10, n).astype(dtype)
-            dev_data = cl_array.to_device(queue, host_data)
+    for n in scan_test_counts:
+        host_data = np.random.randint(0, 10, n).astype(dtype)
+        dev_data = cl_array.to_device(queue, host_data)
 
-            # /!\ fails on Nv GT2?? for some drivers
-            assert (host_data == dev_data.get()).all()
+        # /!\ fails on Nv GT2?? for some drivers
+        assert (host_data == dev_data.get()).all()
 
-            knl(dev_data)
+        knl(dev_data)
 
-            desired_result = np.cumsum(host_data, axis=0)
-            if cls is ExclusiveScanKernel:
-                desired_result -= host_data
+        desired_result = np.cumsum(host_data, axis=0)
+        if scan_cls is ExclusiveScanKernel:
+            desired_result -= host_data
 
-            is_ok = (dev_data.get() == desired_result).all()
-            if 1 and not is_ok:
-                print("something went wrong, summarizing error...")
-                print(summarize_error(dev_data.get(), desired_result, host_data))
+        is_ok = (dev_data.get() == desired_result).all()
+        if 1 and not is_ok:
+            print("something went wrong, summarizing error...")
+            print(summarize_error(dev_data.get(), desired_result, host_data))
 
-            print("n:%d %s worked:%s" % (n, cls, is_ok))
-            assert is_ok
-            from gc import collect
-            collect()
+        print("dtype:%s n:%d %s worked:%s" % (dtype, n, scan_cls, is_ok))
+        assert is_ok
+        from gc import collect
+        collect()
 
 
 def test_copy_if(ctx_factory):
-- 
GitLab