From ce7efa3ff6d878df066dc35741e8ad5812787c7b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 23 Aug 2012 14:27:31 -0400
Subject: [PATCH] Make index_dtype settable.

---
 loopy/codegen/__init__.py   | 7 ++++---
 loopy/codegen/bounds.py     | 5 +++--
 loopy/codegen/expression.py | 2 +-
 loopy/codegen/loop.py       | 3 ++-
 loopy/kernel.py             | 6 ++++--
 5 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 49ccbf055..4d7047df4 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -263,11 +263,12 @@ def generate_code(kernel, with_annotation=False,
 
     # }}}
 
+    from pyopencl.tools import dtype_to_ctype
     mod.extend([
         LiteralLines(r"""
-        #define lid(N) ((int) get_local_id(N))
-        #define gid(N) ((int) get_group_id(N))
-        """),
+        #define lid(N) ((%(idx_ctype)s) get_local_id(N))
+        #define gid(N) ((%(idx_ctype)s) get_group_id(N))
+        """ % dict(idx_ctype=dtype_to_ctype(kernel.index_dtype))),
         Line()])
 
     # {{{ build lmem array declarators for temporary variables
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index c9a9b8660..5bc5c3586 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -146,7 +146,8 @@ def wrap_in_bounds_checks(ccm, domain, check_inames, implemented_domain, stmt):
 
     return stmt, new_implemented_domain
 
-def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
+def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt,
+        index_dtype):
     # FIXME add admissible vars
     if isinstance(constraint_bset, isl.Set):
         constraint_bset, = constraint_bset.get_basic_sets()
@@ -191,7 +192,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
         from loopy.codegen import gen_code_block
         from cgen import Initializer, POD, Const, Line
         return gen_code_block([
-            Initializer(Const(POD(np.int32, iname)),
+            Initializer(Const(POD(index_dtype, iname)),
                 ccm(equality_expr, 'i')),
             Line(),
             stmt,
diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index 82741a249..964d7efc1 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -87,7 +87,7 @@ class TypeInferenceMapper(CombineMapper):
             return tv.dtype
 
         if expr.name in self.kernel.all_inames():
-            return np.dtype(np.int16) # don't force single-precision upcast
+            return self.kernel.index_dtype
 
         for mangler in self.kernel.symbol_manglers:
             result = mangler(expr.name)
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 36a24d5c5..7b4e63c95 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -259,7 +259,8 @@ def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
             from cgen import Comment
             result.append(Comment(cmt))
         result.append(
-                wrap_in_for_from_constraints(ccm, iname, slab, inner))
+                wrap_in_for_from_constraints(ccm, iname, slab, inner,
+                    kernel.index_dtype))
 
     return gen_code_block(result)
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index acbc01811..5e2a060e6 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -649,7 +649,8 @@ class LoopKernel(Record):
             applied_iname_rewrites=[],
             cache_manager=None,
             iname_to_tag_requests=None,
-            lowest_priority_inames=[], breakable_inames=set()):
+            lowest_priority_inames=[], breakable_inames=set(),
+            index_dtype=np.int32):
         """
         :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl.
             Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}"
@@ -895,7 +896,8 @@ class LoopKernel(Record):
                 breakable_inames=breakable_inames,
                 applied_iname_rewrites=applied_iname_rewrites,
                 function_manglers=function_manglers,
-                symbol_manglers=symbol_manglers)
+                symbol_manglers=symbol_manglers,
+                index_dtype=np.dtype(index_dtype))
 
     # {{{ function mangling
 
-- 
GitLab