From 5a67d7da4ff2824d13f1371903fd652c3518c51a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 30 Apr 2013 23:41:06 -0400
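Subject: [PATCH] Use AccessRangeMapper to determine temp var shapes, too.

Move the access range mapper used for guessing argument shapes up in
creation.py, rename it from _AccessRangeMapper to AccessRangeMapper, and
use it in create_temporaries to derive the base indices and shape of a
temporary declared in an instruction. This replaces the earlier
requirement that the lvalue indices of such a declaration be plain inames.

To support this, SetOperationCacheManager.base_index_and_length now also
accepts an integer set-dimension index in place of an iname.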
---
 loopy/kernel/creation.py | 114 +++++++++++++++++++--------------------
 loopy/kernel/tools.py    |  13 +++--
 2 files changed, 67 insertions(+), 60 deletions(-)

diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 94bc195a4..afa70f385 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -36,7 +36,7 @@ from islpy import dim_type
 import re
 
 
-# {{{ unique name generation
+# {{{ tool: unique name generation
 
 def generate_unique_possibilities(prefix):
     yield prefix
@@ -85,6 +85,73 @@ class MakeUnique:
 
 # }}}
 
+# {{{ tool: access range mapper
+
+class AccessRangeMapper(WalkMapper):
+    """Collect the index range with which one named variable is subscripted.
+
+    Apply the mapper as ``mapper(expr, domain)``, once per expression. Each
+    subscript of *arg_name* found contributes its access range (as computed
+    by :func:`loopy.symbolic.get_access_range`); the contributions are
+    merged with ``|`` into :attr:`access_range`, which stays *None* if no
+    such subscript was encountered.
+    """
+
+    def __init__(self, arg_name):
+        # name of the subscripted variable to look for
+        self.arg_name = arg_name
+        # merged access range so far; None until the first match
+        self.access_range = None
+
+    def map_subscript(self, expr, domain):
+        """Merge the access range of *expr* into :attr:`access_range` if
+        *expr* subscripts the variable named :attr:`arg_name`.
+
+        :arg domain: set whose variables are the only names the indices
+            may depend on
+        :raises RuntimeError: if an index depends on a name that is not a
+            variable of *domain*, or if the number of indices differs from
+            that of a previously seen subscript of the same variable
+        """
+        # delegate to the base class first, then inspect this node
+        WalkMapper.map_subscript(self, expr, domain)
+
+        from pymbolic.primitives import Variable
+        assert isinstance(expr.aggregate, Variable)
+
+        if expr.aggregate.name != self.arg_name:
+            return
+
+        # normalize a single index to a one-element tuple
+        subscript = expr.index
+        if not isinstance(subscript, tuple):
+            subscript = (subscript,)
+
+        from loopy.symbolic import get_dependencies, get_access_range
+
+        if not get_dependencies(subscript) <= set(domain.get_var_dict()):
+            raise RuntimeError("cannot determine access range for '%s': "
+                    "undetermined index in '%s'"
+                    % (self.arg_name, ", ".join(str(i) for i in subscript)))
+
+        access_range = get_access_range(domain, subscript)
+
+        if self.access_range is None:
+            self.access_range = access_range
+        else:
+            if (self.access_range.dim(dim_type.set)
+                    != access_range.dim(dim_type.set)):
+                # not necessarily a kernel argument: this mapper is also
+                # applied to assignees when creating temporaries
+                raise RuntimeError(
+                        "error while determining shape of variable '%s': "
+                        "varying number of indices encountered"
+                        % self.arg_name)
+
+            self.access_range = self.access_range | access_range
+
+# }}}
+
 # {{{ expand defines
 
 WORD_RE = re.compile(r"\b([a-zA-Z0-9_]+)\b")
@@ -529,22 +572,18 @@ def create_temporaries(knl):
         if insn.temp_var_type is not None:
             assignee_name = insn.get_assignee_var_name()
 
-            assignee_indices = []
-            from pymbolic.primitives import Variable
-            for index_expr in insn.get_assignee_indices():
-                if (not isinstance(index_expr, Variable)
-                        or not index_expr.name in knl.all_inames()):
-                    raise RuntimeError(
-                            "only plain inames are allowed in "
-                            "the lvalue index when declaring the "
-                            "variable '%s' in an instruction"
-                            % assignee_name)
+            armap = AccessRangeMapper(assignee_name)
+            domain = knl.get_inames_domain(knl.insn_inames(insn))
+            armap(insn.assignee, domain)
 
-                assignee_indices.append(index_expr.name)
-
-            base_indices, shape = \
-                    knl.find_var_base_indices_and_shape_from_inames(
-                            assignee_indices, knl.cache_manager)
+            if armap.access_range is not None:
+                base_indices, shape = zip(*[
+                        knl.cache_manager.base_index_and_length(
+                            armap.access_range, i)
+                        for i in xrange(armap.access_range.dim(dim_type.set))])
+            else:
+                base_indices = ()
+                shape = ()
 
             if assignee_name in new_temp_vars:
                 raise RuntimeError("cannot create temporary variable '%s'--"
@@ -595,7 +634,7 @@ def check_for_reduction_inames_duplication_requests(kernel):
 
 # }}}
 
-# {{{
+# {{{ apply default_order to args
 
 def apply_default_order_to_args(kernel, default_order):
     from loopy.kernel.data import ShapedArg
@@ -634,45 +673,6 @@ def dup_args_and_expand_defines_in_shapes(kernel, defines):
 
 # {{{ guess argument shapes
 
-class _AccessRangeMapper(WalkMapper):
-    def __init__(self, arg_name):
-        self.arg_name = arg_name
-        self.access_range = None
-
-    def map_subscript(self, expr, domain):
-        WalkMapper.map_subscript(self, expr, domain)
-
-        from pymbolic.primitives import Variable
-        assert isinstance(expr.aggregate, Variable)
-
-        if expr.aggregate.name != self.arg_name:
-            return
-
-        subscript = expr.index
-        if not isinstance(subscript, tuple):
-            subscript = (subscript,)
-
-        from loopy.symbolic import get_dependencies, get_access_range
-
-        if not get_dependencies(subscript) <= set(domain.get_var_dict()):
-            raise RuntimeError("cannot determine access range for '%s': "
-                    "undetermined index in '%s'"
-                    % (self.arg_name, ", ".join(str(i) for i in subscript)))
-
-        access_range = get_access_range(domain, subscript)
-
-        if self.access_range is None:
-            self.access_range = access_range
-        else:
-            if (self.access_range.dim(dim_type.set)
-                    != access_range.dim(dim_type.set)):
-                raise RuntimeError(
-                        "error while determining shape of argument '%s': "
-                        "varying number of indices encountered"
-                        % self.arg_name)
-
-            self.access_range = self.access_range | access_range
-
 def guess_arg_shape_if_requested(kernel, default_order):
     new_args = []
 
@@ -681,7 +681,7 @@ def guess_arg_shape_if_requested(kernel, default_order):
     for arg in kernel.args:
         if isinstance(arg, ShapedArg) and (
                 arg.shape is auto_shape or arg.strides is auto_strides):
-            armap = _AccessRangeMapper(arg.name)
+            armap = AccessRangeMapper(arg.name)
 
             for insn in kernel.instructions:
                 domain = kernel.get_inames_domain(kernel.insn_inames(insn))
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 1ea990758..8ada7da8b 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -254,9 +254,16 @@ class SetOperationCacheManager:
         return self.op(set, "dim_max", set.dim_max, args)
 
     def base_index_and_length(self, set, iname, context=None):
-        iname_to_dim = set.space.get_var_dict()
-        lower_bound_pw_aff = self.dim_min(set, iname_to_dim[iname][1])
-        upper_bound_pw_aff = self.dim_max(set, iname_to_dim[iname][1])
+        if not isinstance(iname, int):
+            iname_to_dim = set.space.get_var_dict()
+            idx = iname_to_dim[iname][1]
+        else:
+            idx = iname
+
+        del iname
+
+        lower_bound_pw_aff = self.dim_min(set, idx)
+        upper_bound_pw_aff = self.dim_max(set, idx)
 
         from loopy.isl_helpers import static_max_of_pw_aff, static_value_of_pw_aff
         from loopy.symbolic import pw_aff_to_expr
-- 
GitLab