From 405c789e8a167404f3d3777ced7fd877d646cc68 Mon Sep 17 00:00:00 2001
From: Lucas C Wilcox <lucas@swirlee.com>
Date: Thu, 7 Jan 2016 19:34:22 -0600
Subject: [PATCH] Pass kernel to get_{global,local}_axis_expr, use signed
 indices on CUDA

---
 loopy/codegen/loop.py    |  4 ++--
 loopy/target/__init__.py |  4 ++--
 loopy/target/cuda.py     | 19 +++++++++++++++----
 loopy/target/ispc.py     |  4 ++--
 loopy/target/opencl.py   |  4 ++--
 5 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 46bf0aa24..98e12cca7 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -244,9 +244,9 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state,
 
     assert isinstance(tag, UniqueTag)
     if isinstance(tag, LocalIndexTag):
-        hw_axis_expr = kernel.target.get_local_axis_expr(tag.axis)
+        hw_axis_expr = kernel.target.get_local_axis_expr(kernel, tag.axis)
     elif isinstance(tag, GroupIndexTag):
-        hw_axis_expr = kernel.target.get_global_axis_expr(tag.axis)
+        hw_axis_expr = kernel.target.get_global_axis_expr(kernel, tag.axis)
     else:
         raise RuntimeError("unexpected hw tag type")
 
diff --git a/loopy/target/__init__.py b/loopy/target/__init__.py
index ba83b07ee..5b51808e0 100644
--- a/loopy/target/__init__.py
+++ b/loopy/target/__init__.py
@@ -100,10 +100,10 @@ class TargetBase(object):
     def get_expression_to_code_mapper(self, codegen_state):
         raise NotImplementedError()
 
-    def get_global_axis_expr(self, axis):
+    def get_global_axis_expr(self, kernel, axis):
         raise NotImplementedError()
 
-    def get_local_axis_expr(self, axis):
+    def get_local_axis_expr(self, kernel, axis):
         raise NotImplementedError()
 
     def add_vector_access(self, access_str, index):
diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py
index 6f66e7b99..efe755dca 100644
--- a/loopy/target/cuda.py
+++ b/loopy/target/cuda.py
@@ -200,11 +200,22 @@ class CudaTarget(CTarget):
 
     _GRID_AXES = "xyz"
 
-    def get_global_axis_expr(self, axis):
-        return var("blockIdx").attr(self._GRID_AXES[axis])
+    @staticmethod
+    def _get_index_ctype(kernel):
+        # C name of the signed integer type used for hardware axis indices;
+        # every index dtype other than int32 is emitted as int64_t.
+        use_32bit = kernel.index_dtype == np.int32
+        return "int32_t" if use_32bit else "int64_t"
+
+    def get_global_axis_expr(self, kernel, axis):
+        # Cast blockIdx.<axis> to the C type chosen for the kernel's indices.
+        ctype = self._get_index_ctype(kernel)
+        return var("((%s) blockIdx.%s)" % (ctype, self._GRID_AXES[axis]))
 
-    def get_local_axis_expr(self, axis):
-        return var("threadIdx").attr(self._GRID_AXES[axis])
+    def get_local_axis_expr(self, kernel, axis):
+        # Cast threadIdx.<axis> to the C type chosen for the kernel's indices.
+        ctype = self._get_index_ctype(kernel)
+        return var("((%s) threadIdx.%s)" % (ctype, self._GRID_AXES[axis]))
 
     _VEC_AXES = "xyzw"
 
diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py
index e0c4b75a3..cf11092df 100644
--- a/loopy/target/ispc.py
+++ b/loopy/target/ispc.py
@@ -101,10 +101,10 @@ class ISPCTarget(CTarget):
 
     # {{{ code generation guts
 
-    def get_global_axis_expr(self, axis):
+    def get_global_axis_expr(self, kernel, axis):
         return var("taskIndex%d" % axis)
 
-    def get_local_axis_expr(self, axis):
+    def get_local_axis_expr(self, kernel, axis):
         if axis == 0:
             return var("programIndex")
         else:
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 4b88e6555..cf130a095 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -267,10 +267,10 @@ class OpenCLTarget(CTarget):
 
     # {{{ code generation guts
 
-    def get_global_axis_expr(self, axis):
+    def get_global_axis_expr(self, kernel, axis):
         return var("gid")(axis)
 
-    def get_local_axis_expr(self, axis):
+    def get_local_axis_expr(self, kernel, axis):
         return var("lid")(axis)
 
     def add_vector_access(self, access_str, index):
-- 
GitLab