diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 3920899386493b3dff350aadd55a3a74b990c6d8..8d340114294b53dc822d9e2fe0dd0309e32e8ce4 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -85,48 +85,6 @@ class MakeUnique: # }}} -# {{{ tool: access range mapper - -class AccessRangeMapper(WalkMapper): - def __init__(self, arg_name): - self.arg_name = arg_name - self.access_range = None - - def map_subscript(self, expr, domain): - WalkMapper.map_subscript(self, expr, domain) - - from pymbolic.primitives import Variable - assert isinstance(expr.aggregate, Variable) - - if expr.aggregate.name != self.arg_name: - return - - subscript = expr.index - if not isinstance(subscript, tuple): - subscript = (subscript,) - - from loopy.symbolic import get_dependencies, get_access_range - - if not get_dependencies(subscript) <= set(domain.get_var_dict()): - raise RuntimeError("cannot determine access range for '%s': " - "undetermined index in '%s'" - % (self.arg_name, ", ".join(str(i) for i in subscript))) - - access_range = get_access_range(domain, subscript) - - if self.access_range is None: - self.access_range = access_range - else: - if (self.access_range.dim(dim_type.set) - != access_range.dim(dim_type.set)): - raise RuntimeError( - "error while determining shape of argument '%s': " - "varying number of indices encountered" - % self.arg_name) - - self.access_range = self.access_range | access_range - -# }}} # {{{ expand defines @@ -669,15 +627,16 @@ def create_temporaries(knl): new_insns = [] new_temp_vars = knl.temporary_variables.copy() + from loopy.symbolic import AccessRangeMapper + for insn in knl.instructions: from loopy.kernel.data import TemporaryVariable if insn.temp_var_type is not None: assignee_name = insn.get_assignee_var_name() - armap = AccessRangeMapper(assignee_name) - domain = knl.get_inames_domain(knl.insn_inames(insn)) - armap(insn.assignee, domain) + armap = AccessRangeMapper(knl, assignee_name) + armap(insn.assignee, knl.insn_inames(insn)) if armap.access_range is not None: base_indices, shape = zip(*[ @@ -737,21 +696,6 @@ def check_for_reduction_inames_duplication_requests(kernel): # }}} -# {{{ apply default_order to args - -def apply_default_order_to_args(kernel, default_order): - from loopy.kernel.data import ShapedArg - - processed_args = [] - for arg in kernel.args: - if isinstance(arg, ShapedArg): - arg = arg.copy(order=default_order) - processed_args.append(arg) - - return kernel.copy(args=processed_args) - -# }}} - # {{{ duplicate arguments and expand defines in shapes def dup_args_and_expand_defines_in_shapes(kernel, defines): @@ -785,7 +729,7 @@ def guess_arg_shape_if_requested(kernel, default_order): import loopy as lp from loopy.kernel.data import ShapedArg - from loopy.symbolic import SubstitutionRuleExpander + from loopy.symbolic import SubstitutionRuleExpander, AccessRangeMapper submap = SubstitutionRuleExpander(kernel.substitutions, kernel.get_var_name_generator()) @@ -793,12 +737,11 @@ def guess_arg_shape_if_requested(kernel, default_order): for arg in kernel.args: if isinstance(arg, ShapedArg) and ( arg.shape is lp.auto or arg.strides is lp.auto): - armap = AccessRangeMapper(arg.name) + armap = AccessRangeMapper(kernel, arg.name) for insn in kernel.instructions: - domain = kernel.get_inames_domain(kernel.insn_inames(insn)) - armap(submap(insn.assignee, insn.id), domain) - armap(submap(insn.expression, insn.id), domain) + armap(submap(insn.assignee, insn.id), kernel.insn_inames(insn)) + armap(submap(insn.expression, insn.id), kernel.insn_inames(insn)) if armap.access_range is None: # no subscripts found, let's call it a scalar @@ -825,6 +768,21 @@ def guess_arg_shape_if_requested(kernel, default_order): # }}} +# {{{ apply default_order to args + +def apply_default_order_to_args(kernel, default_order): + from loopy.kernel.data import ShapedArg + + processed_args = [] + for arg in kernel.args: + if isinstance(arg, ShapedArg): + arg = arg.copy(order=default_order) + processed_args.append(arg) + + return kernel.copy(args=processed_args) + +# }}} + # {{{ kernel creation top-level def make_kernel(device, domains, instructions, kernel_args=["..."], **kwargs): diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 91f7e4907f281b586317c221eef04cbe99002d39..7267d7be647c7869ffb1ef66179ddf8c8e0b15b0 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -980,6 +980,54 @@ def get_access_range(domain, subscript): # }}} +# {{{ access range mapper + +class AccessRangeMapper(WalkMapper): + def __init__(self, kernel, arg_name): + self.kernel = kernel + self.arg_name = arg_name + self.access_range = None + + def map_subscript(self, expr, inames): + domain = self.kernel.get_inames_domain(inames) + WalkMapper.map_subscript(self, expr, domain) + + from pymbolic.primitives import Variable + assert isinstance(expr.aggregate, Variable) + + if expr.aggregate.name != self.arg_name: + return + + subscript = expr.index + if not isinstance(subscript, tuple): + subscript = (subscript,) + + from loopy.symbolic import get_dependencies, get_access_range + + if not get_dependencies(subscript) <= set(domain.get_var_dict()): + raise RuntimeError("cannot determine access range for '%s': " + "undetermined index in '%s'" + % (self.arg_name, ", ".join(str(i) for i in subscript))) + + access_range = get_access_range(domain, subscript) + + if self.access_range is None: + self.access_range = access_range + else: + if (self.access_range.dim(dim_type.set) + != access_range.dim(dim_type.set)): + raise RuntimeError( + "error while determining shape of argument '%s': " + "varying number of indices encountered" + % self.arg_name) + + self.access_range = self.access_range | access_range + + def map_reduction(self, expr, inames): + return WalkMapper.map_reduction(self, expr, inames | set(expr.inames)) + +# }}} + # vim: foldmethod=marker