From 5a67d7da4ff2824d13f1371903fd652c3518c51a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 30 Apr 2013 23:41:06 -0400 Subject: [PATCH] Use AccessRangeMapper to determine temp var shapes, too. --- loopy/kernel/creation.py | 114 +++++++++++++++++++-------------------- loopy/kernel/tools.py | 13 +++-- 2 files changed, 67 insertions(+), 60 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 94bc195a4..afa70f385 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -36,7 +36,7 @@ from islpy import dim_type import re -# {{{ unique name generation +# {{{ tool: unique name generation def generate_unique_possibilities(prefix): yield prefix @@ -85,6 +85,49 @@ class MakeUnique: # }}} +# {{{ tool: access range mapper + +class AccessRangeMapper(WalkMapper): + def __init__(self, arg_name): + self.arg_name = arg_name + self.access_range = None + + def map_subscript(self, expr, domain): + WalkMapper.map_subscript(self, expr, domain) + + from pymbolic.primitives import Variable + assert isinstance(expr.aggregate, Variable) + + if expr.aggregate.name != self.arg_name: + return + + subscript = expr.index + if not isinstance(subscript, tuple): + subscript = (subscript,) + + from loopy.symbolic import get_dependencies, get_access_range + + if not get_dependencies(subscript) <= set(domain.get_var_dict()): + raise RuntimeError("cannot determine access range for '%s': " + "undetermined index in '%s'" + % (self.arg_name, ", ".join(str(i) for i in subscript))) + + access_range = get_access_range(domain, subscript) + + if self.access_range is None: + self.access_range = access_range + else: + if (self.access_range.dim(dim_type.set) + != access_range.dim(dim_type.set)): + raise RuntimeError( + "error while determining shape of argument '%s': " + "varying number of indices encountered" + % self.arg_name) + + self.access_range = self.access_range | access_range + +# }}} + # {{{ expand defines WORD_RE = re.compile(r"\b([a-zA-Z0-9_]+)\b") @@ -529,22 +572,18 @@ def create_temporaries(knl): if insn.temp_var_type is not None: assignee_name = insn.get_assignee_var_name() - assignee_indices = [] - from pymbolic.primitives import Variable - for index_expr in insn.get_assignee_indices(): - if (not isinstance(index_expr, Variable) - or not index_expr.name in knl.all_inames()): - raise RuntimeError( - "only plain inames are allowed in " - "the lvalue index when declaring the " - "variable '%s' in an instruction" - % assignee_name) + armap = AccessRangeMapper(assignee_name) + domain = knl.get_inames_domain(knl.insn_inames(insn)) + armap(insn.assignee, domain) - assignee_indices.append(index_expr.name) - - base_indices, shape = \ - knl.find_var_base_indices_and_shape_from_inames( - assignee_indices, knl.cache_manager) + if armap.access_range is not None: + base_indices, shape = zip(*[ + knl.cache_manager.base_index_and_length( + armap.access_range, i) + for i in xrange(armap.access_range.dim(dim_type.set))]) + else: + base_indices = () + shape = () if assignee_name in new_temp_vars: raise RuntimeError("cannot create temporary variable '%s'--" @@ -595,7 +634,7 @@ def check_for_reduction_inames_duplication_requests(kernel): # }}} -# {{{ +# {{{ apply default_order to args def apply_default_order_to_args(kernel, default_order): from loopy.kernel.data import ShapedArg @@ -634,45 +673,6 @@ def dup_args_and_expand_defines_in_shapes(kernel, defines): # {{{ guess argument shapes -class _AccessRangeMapper(WalkMapper): - def __init__(self, arg_name): - self.arg_name = arg_name - self.access_range = None - - def map_subscript(self, expr, domain): - WalkMapper.map_subscript(self, expr, domain) - - from pymbolic.primitives import Variable - assert isinstance(expr.aggregate, Variable) - - if expr.aggregate.name != self.arg_name: - return - - subscript = expr.index - if not isinstance(subscript, tuple): - subscript = (subscript,) - - from loopy.symbolic import get_dependencies, get_access_range - - if not get_dependencies(subscript) <= set(domain.get_var_dict()): - raise RuntimeError("cannot determine access range for '%s': " - "undetermined index in '%s'" - % (self.arg_name, ", ".join(str(i) for i in subscript))) - - access_range = get_access_range(domain, subscript) - - if self.access_range is None: - self.access_range = access_range - else: - if (self.access_range.dim(dim_type.set) - != access_range.dim(dim_type.set)): - raise RuntimeError( - "error while determining shape of argument '%s': " - "varying number of indices encountered" - % self.arg_name) - - self.access_range = self.access_range | access_range - def guess_arg_shape_if_requested(kernel, default_order): new_args = [] @@ -681,7 +681,7 @@ def guess_arg_shape_if_requested(kernel, default_order): for arg in kernel.args: if isinstance(arg, ShapedArg) and ( arg.shape is auto_shape or arg.strides is auto_strides): - armap = _AccessRangeMapper(arg.name) + armap = AccessRangeMapper(arg.name) for insn in kernel.instructions: domain = kernel.get_inames_domain(kernel.insn_inames(insn)) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 1ea990758..8ada7da8b 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -254,9 +254,16 @@ class SetOperationCacheManager: return self.op(set, "dim_max", set.dim_max, args) def base_index_and_length(self, set, iname, context=None): - iname_to_dim = set.space.get_var_dict() - lower_bound_pw_aff = self.dim_min(set, iname_to_dim[iname][1]) - upper_bound_pw_aff = self.dim_max(set, iname_to_dim[iname][1]) + if not isinstance(iname, int): + iname_to_dim = set.space.get_var_dict() + idx = iname_to_dim[iname][1] + else: + idx = iname + + del iname + + lower_bound_pw_aff = self.dim_min(set, idx) + upper_bound_pw_aff = self.dim_max(set, idx) from loopy.isl_helpers import static_max_of_pw_aff, static_value_of_pw_aff from loopy.symbolic import pw_aff_to_expr -- GitLab