From 1a1ed4dd4985bdbebb2fad4b4244d0f8de3b8ffa Mon Sep 17 00:00:00 2001
From: Tim Warburton <timwar@caam.rice.edu>
Date: Tue, 1 Nov 2011 23:37:14 -0500
Subject: [PATCH] Make temp. variable shapes tuples of ints (not PwAffs).

---
 loopy/codegen/__init__.py | 2 +-
 loopy/kernel.py           | 7 +++----
 loopy/symbolic.py         | 5 +++++
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index b093c44df..dd4b4dec5 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -283,7 +283,7 @@ def generate_code(kernel):
 
         from loopy.symbolic import pw_aff_to_expr
         for l in storage_shape:
-            temp_var_decl = ArrayOf(temp_var_decl, int(pw_aff_to_expr(l)))
+            temp_var_decl = ArrayOf(temp_var_decl, l)
 
         if tv.is_local:
             temp_var_decl = CLLocal(temp_var_decl)
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 2a56ff718..52974a83c 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -203,8 +203,7 @@ class TemporaryVariable(Record):
     @property
     def nbytes(self):
         from pytools import product
-        from loopy.symbolic import pw_aff_to_expr
-        return product(pw_aff_to_expr(si) for si in self.shape)*self.dtype.itemsize
+        return product(si for si in self.shape)*self.dtype.itemsize
 
 # }}}
 
@@ -858,8 +857,8 @@ def find_var_base_indices_and_shape_from_inames(domain, inames):
         from loopy.isl_helpers import static_max_of_pw_aff
         from loopy.symbolic import pw_aff_to_expr
 
-        shape.append(static_max_of_pw_aff(
-                upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True))
+        shape.append(pw_aff_to_expr(static_max_of_pw_aff(
+                upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True)))
         base_indices.append(pw_aff_to_expr(lower_bound_pw_aff))
 
     return base_indices, shape
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 8f28afb33..ba4f05b4f 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -432,6 +432,11 @@ def aff_to_expr(aff, except_name=None, error_on_name=None):
 
 
 def pw_aff_to_expr(pw_aff):
+    if isinstance(pw_aff, int):
+        from warnings import warn
+        warn("expected PwAff, got int", stacklevel=2)
+        return pw_aff
+
     pieces = pw_aff.get_pieces()
 
     if len(pieces) != 1:
-- 
GitLab