From 1a1ed4dd4985bdbebb2fad4b4244d0f8de3b8ffa Mon Sep 17 00:00:00 2001 From: Tim Warburton <timwar@caam.rice.edu> Date: Tue, 1 Nov 2011 23:37:14 -0500 Subject: [PATCH] Make temp. variable shapes tuples of ints (not PwAffs). --- loopy/codegen/__init__.py | 2 +- loopy/kernel.py | 7 +++---- loopy/symbolic.py | 5 +++++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index b093c44df..dd4b4dec5 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -283,7 +283,7 @@ def generate_code(kernel): from loopy.symbolic import pw_aff_to_expr for l in storage_shape: - temp_var_decl = ArrayOf(temp_var_decl, int(pw_aff_to_expr(l))) + temp_var_decl = ArrayOf(temp_var_decl, l) if tv.is_local: temp_var_decl = CLLocal(temp_var_decl) diff --git a/loopy/kernel.py b/loopy/kernel.py index 2a56ff718..52974a83c 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -203,8 +203,7 @@ class TemporaryVariable(Record): @property def nbytes(self): from pytools import product - from loopy.symbolic import pw_aff_to_expr - return product(pw_aff_to_expr(si) for si in self.shape)*self.dtype.itemsize + return product(si for si in self.shape)*self.dtype.itemsize # }}} @@ -858,8 +857,8 @@ def find_var_base_indices_and_shape_from_inames(domain, inames): from loopy.isl_helpers import static_max_of_pw_aff from loopy.symbolic import pw_aff_to_expr - shape.append(static_max_of_pw_aff( - upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True)) + shape.append(pw_aff_to_expr(static_max_of_pw_aff( + upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True))) base_indices.append(pw_aff_to_expr(lower_bound_pw_aff)) return base_indices, shape diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 8f28afb33..ba4f05b4f 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -432,6 +432,11 @@ def aff_to_expr(aff, except_name=None, error_on_name=None): def pw_aff_to_expr(pw_aff): + if isinstance(pw_aff, int): + from warnings import warn + warn("expected PwAff, got int", stacklevel=2) + return pw_aff + pieces = pw_aff.get_pieces() if len(pieces) != 1: -- GitLab