From 0ebd02f1e67c4b5cc592f0145b8c5705f178397f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 7 Jan 2016 23:03:57 -0600
Subject: [PATCH] Introduce placeholders for hw axes, rather than using
 target-specific expressions

---
 loopy/codegen/loop.py                | 10 +++---
 loopy/expression.py                  |  6 ++++
 loopy/symbolic.py                    | 51 ++++++++++++++++++++++++++--
 loopy/target/__init__.py             |  6 ----
 loopy/target/c/codegen/expression.py |  6 ++++
 loopy/target/cuda.py                 | 48 +++++++++++++++-----------
 loopy/target/ispc.py                 | 26 +++++++++-----
 loopy/target/opencl.py               | 22 ++++++++----
 8 files changed, 127 insertions(+), 48 deletions(-)

diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 98e12cca7..6d0b2ca60 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -242,11 +242,13 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state,
 
     tag = kernel.iname_to_tag.get(iname)
 
+    from loopy.symbolic import GroupHardwareAxisIndex, LocalHardwareAxisIndex
+
     assert isinstance(tag, UniqueTag)
-    if isinstance(tag, LocalIndexTag):
-        hw_axis_expr = kernel.target.get_local_axis_expr(kernel, tag.axis)
-    elif isinstance(tag, GroupIndexTag):
-        hw_axis_expr = kernel.target.get_global_axis_expr(kernel, tag.axis)
+    if isinstance(tag, GroupIndexTag):
+        hw_axis_expr = GroupHardwareAxisIndex(tag.axis)
+    elif isinstance(tag, LocalIndexTag):
+        hw_axis_expr = LocalHardwareAxisIndex(tag.axis)
     else:
         raise RuntimeError("unexpected hw tag type")
 
diff --git a/loopy/expression.py b/loopy/expression.py
index e9f7d8410..62c2278be 100644
--- a/loopy/expression.py
+++ b/loopy/expression.py
@@ -257,6 +257,12 @@ class TypeInferenceMapper(CombineMapper):
     map_logical_and = map_comparison
     map_logical_or = map_comparison
 
+    def map_group_hw_index(self, expr, *args):
+        return self.kernel.index_dtype
+
+    def map_local_hw_index(self, expr, *args):
+        return self.kernel.index_dtype
+
     def map_reduction(self, expr):
         return expr.operation.result_dtype(
                 self.kernel.target, self.rec(expr.expr), expr.inames)
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index b3dfce3d6..7adab80c6 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -69,6 +69,15 @@ import numpy as np
 # {{{ mappers with support for loopy-specific primitives
 
 class IdentityMapperMixin(object):
+    def map_group_hw_index(self, expr, *args):
+        return expr
+
+    def map_local_hw_index(self, expr, *args):
+        return expr
+
+    def map_loopy_function_identifier(self, expr, *args):
+        return expr
+
     def map_reduction(self, expr, *args):
         return Reduction(expr.operation, expr.inames, self.rec(expr.expr, *args))
 
@@ -76,9 +85,6 @@ class IdentityMapperMixin(object):
         # leaf, doesn't change
         return expr
 
-    def map_loopy_function_identifier(self, expr, *args):
-        return expr
-
     map_linear_subscript = IdentityMapperBase.map_subscript
 
 
@@ -92,6 +98,12 @@ class PartialEvaluationMapper(EvaluationMapperBase, IdentityMapperMixin):
 
 
 class WalkMapper(WalkMapperBase):
+    def map_group_hw_index(self, expr, *args):
+        self.visit(expr)
+
+    def map_local_hw_index(self, expr, *args):
+        self.visit(expr)
+
     def map_reduction(self, expr, *args):
         if not self.visit(expr):
             return
@@ -127,6 +139,12 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,
 
 
 class StringifyMapper(StringifyMapperBase):
+    def map_group_hw_index(self, expr, enclosing_prec):
+        return "grp.%d" % expr.axis  # HardwareAxisIndex stores 'axis'
+
+    def map_local_hw_index(self, expr, enclosing_prec):
+        return "loc.%d" % expr.axis  # HardwareAxisIndex stores 'axis'
+
     def map_reduction(self, expr, prec):
         return "reduce(%s, [%s], %s)" % (
                 expr.operation, ", ".join(expr.inames), expr.expr)
@@ -177,6 +195,12 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
 
 
 class DependencyMapper(DependencyMapperBase):
+    def map_group_hw_index(self, expr):
+        return set()
+
+    def map_local_hw_index(self, expr):
+        return set()
+
     def map_call(self, expr, *args):
         # Loopy does not have first-class functions. Do not descend
         # into 'function' attribute of Call.
@@ -235,6 +259,27 @@ class SubstitutionRuleExpander(IdentityMapper):
 
 # {{{ loopy-specific primitives
 
+class HardwareAxisIndex(Leaf):
+    def __init__(self, axis):
+        self.axis = axis
+
+    def stringifier(self):
+        return StringifyMapper
+
+    def __getinitargs__(self):
+        return (self.axis,)
+
+    init_arg_names = ("axis",)
+
+
+class GroupHardwareAxisIndex(HardwareAxisIndex):
+    mapper_method = "map_group_hw_index"
+
+
+class LocalHardwareAxisIndex(HardwareAxisIndex):
+    mapper_method = "map_local_hw_index"
+
+
 class FunctionIdentifier(Leaf):
     """A base class for symbols representing functions."""
 
diff --git a/loopy/target/__init__.py b/loopy/target/__init__.py
index 5b51808e0..b8c903fc1 100644
--- a/loopy/target/__init__.py
+++ b/loopy/target/__init__.py
@@ -100,12 +100,6 @@ class TargetBase(object):
     def get_expression_to_code_mapper(self, codegen_state):
         raise NotImplementedError()
 
-    def get_global_axis_expr(self, kernel, axis):
-        raise NotImplementedError()
-
-    def get_local_axis_expr(self, kernel, axis):
-        raise NotImplementedError()
-
     def add_vector_access(self, access_str, index):
         raise NotImplementedError()
 
diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index 97bec6e59..061518558 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -670,6 +670,12 @@ class LoopyCCodeMapper(RecursiveMapper):
 
     # }}}
 
+    def map_group_hw_index(self, expr, enclosing_prec, type_context):
+        raise LoopyError("plain C does not have group hw axes")
+
+    def map_local_hw_index(self, expr, enclosing_prec, type_context):
+        raise LoopyError("plain C does not have local hw axes")
+
 # }}}
 
 # vim: fdm=marker
diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py
index 992d5db85..9a9aee76a 100644
--- a/loopy/target/cuda.py
+++ b/loopy/target/cuda.py
@@ -27,11 +27,10 @@ THE SOFTWARE.
 import numpy as np
 
 from loopy.target.c import CTarget
+from loopy.target.c.codegen.expression import LoopyCCodeMapper
 from pytools import memoize_method
 from loopy.diagnostic import LoopyError
 
-from pymbolic import var
-
 
 # {{{ vector types
 
@@ -135,6 +134,31 @@ def cuda_function_mangler(kernel, name, arg_dtypes):
 # }}}
 
 
+# {{{ expression mapper
+
+class LoopyCudaCCodeMapper(LoopyCCodeMapper):
+    _GRID_AXES = "xyz"  # axis number -> blockIdx/threadIdx component name
+
+    @staticmethod
+    def _get_index_ctype(kernel):
+        if kernel.index_dtype == np.int32:
+            return "int32_t"
+        else:
+            return "int64_t"  # any non-int32 index dtype maps to int64_t
+
+    def map_group_hw_index(self, expr, enclosing_prec, type_context):
+        return "((%s) blockIdx.%s)" % (  # cast to the kernel's index type
+            self._get_index_ctype(self.kernel),
+            self._GRID_AXES[expr.axis])
+
+    def map_local_hw_index(self, expr, enclosing_prec, type_context):
+        return "((%s) threadIdx.%s)" % (  # cast to the kernel's index type
+            self._get_index_ctype(self.kernel),
+            self._GRID_AXES[expr.axis])
+
+# }}}
+
+
 # {{{ target
 
 class CudaTarget(CTarget):
@@ -216,24 +240,8 @@ class CudaTarget(CTarget):
 
     # {{{ code generation guts
 
-    _GRID_AXES = "xyz"
-
-    @staticmethod
-    def _get_index_ctype(kernel):
-        if kernel.index_dtype == np.int32:
-            return "int32_t"
-        else:
-            return "int64_t"
-
-    def get_global_axis_expr(self, kernel, axis):
-        return var("((%s) blockIdx.%s)" % (
-            self._get_index_ctype(kernel),
-            self._GRID_AXES[axis]))
-
-    def get_local_axis_expr(self, kernel, axis):
-        return var("((%s) threadIdx.%s)" % (
-            self._get_index_ctype(kernel),
-            self._GRID_AXES[axis]))
+    def get_expression_to_code_mapper(self, codegen_state):
+        return LoopyCudaCCodeMapper(codegen_state)
 
     _VEC_AXES = "xyzw"
 
diff --git a/loopy/target/ispc.py b/loopy/target/ispc.py
index cf11092df..2d146e82a 100644
--- a/loopy/target/ispc.py
+++ b/loopy/target/ispc.py
@@ -27,11 +27,27 @@ THE SOFTWARE.
 
 import numpy as np  # noqa
 from loopy.target.c import CTarget
+from loopy.target.c.codegen.expression import LoopyCCodeMapper
 from loopy.diagnostic import LoopyError
 
 from pymbolic import var
 
 
+# {{{ expression mapper
+
+class LoopyISPCCodeMapper(LoopyCCodeMapper):
+    def map_group_hw_index(self, expr, enclosing_prec, type_context):
+        return "taskIndex%d" % expr.axis
+
+    def map_local_hw_index(self, expr, enclosing_prec, type_context):
+        if expr.axis == 0:
+            return "programIndex"  # code mappers emit strings, not var()
+        else:
+            raise LoopyError("ISPC only supports one local axis")
+
+# }}}
+
+
 class ISPCTarget(CTarget):
     # {{{ top-level codegen
 
@@ -101,14 +117,8 @@ class ISPCTarget(CTarget):
 
     # {{{ code generation guts
 
-    def get_global_axis_expr(self, kernel, axis):
-        return var("taskIndex%d" % axis)
-
-    def get_local_axis_expr(self, kernel, axis):
-        if axis == 0:
-            return var("programIndex")
-        else:
-            raise LoopyError("ISPC only supports one local axis")
+    def get_expression_to_code_mapper(self, codegen_state):
+        return LoopyISPCCodeMapper(codegen_state)
 
     def add_vector_access(self, access_str, index):
         return "(%s)[%d]" % (access_str, index)
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index cf130a095..7ef944d32 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -27,11 +27,10 @@ THE SOFTWARE.
 import numpy as np
 
 from loopy.target.c import CTarget
+from loopy.target.c.codegen.expression import LoopyCCodeMapper
 from pytools import memoize_method
 from loopy.diagnostic import LoopyError
 
-from pymbolic import var
-
 
 # {{{ vector types
 
@@ -175,6 +174,18 @@ def opencl_preamble_generator(kernel, seen_dtypes, seen_functions):
 # }}}
 
 
+# {{{ expression mapper
+
+class LoopyOpenCLCCodeMapper(LoopyCCodeMapper):
+    def map_group_hw_index(self, expr, enclosing_prec, type_context):
+        return "gid(%d)" % expr.axis
+
+    def map_local_hw_index(self, expr, enclosing_prec, type_context):
+        return "lid(%d)" % expr.axis
+
+# }}}
+
+
 # {{{ target
 
 class OpenCLTarget(CTarget):
@@ -267,11 +278,8 @@ class OpenCLTarget(CTarget):
 
     # {{{ code generation guts
 
-    def get_global_axis_expr(self, kernel, axis):
-        return var("gid")(axis)
-
-    def get_local_axis_expr(self, kernel, axis):
-        return var("lid")(axis)
+    def get_expression_to_code_mapper(self, codegen_state):
+        return LoopyOpenCLCCodeMapper(codegen_state)
 
     def add_vector_access(self, access_str, index):
         # The 'int' avoids an 'L' suffix for long ints.
-- 
GitLab