diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 019b899594ee1670d1ee59654f58b93ff8afacb9..d1efd4a44254a6d829ce63b66bb228cea1efab6a 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1392,11 +1392,11 @@ def create_temporaries(knl, default_order): # {{{ determine shapes of temporaries -def find_var_shape(knl, var_name, feed_expression): - from loopy.symbolic import AccessRangeMapper, SubstitutionRuleExpander +def find_shapes_of_vars(knl, var_names, feed_expression): + from loopy.symbolic import BatchedAccessRangeMapper, SubstitutionRuleExpander submap = SubstitutionRuleExpander(knl.substitutions) - armap = AccessRangeMapper(knl, var_name) + armap = BatchedAccessRangeMapper(knl, var_names) def run_through_armap(expr, inames): armap(submap(expr), inames) @@ -1404,61 +1404,105 @@ def find_var_shape(knl, var_name, feed_expression): feed_expression(run_through_armap) - if armap.access_range is not None: - base_indices, shape = list(zip(*[ - knl.cache_manager.base_index_and_length( - armap.access_range, i) - for i in range(armap.access_range.dim(dim_type.set))])) - else: - if armap.bad_subscripts: - raise RuntimeError("cannot determine access range for '%s': " - "undetermined index in subscript(s) '%s'" - % (var_name, ", ".join( - str(i) for i in armap.bad_subscripts))) + var_to_base_indices = {} + var_to_shape = {} + var_to_error = {} + + from loopy.diagnostic import StaticValueFindingError + + for var_name in var_names: + access_range = armap.access_ranges[var_name] + bad_subscripts = armap.bad_subscripts[var_name] + + if access_range is not None: + try: + base_indices, shape = list(zip(*[ + knl.cache_manager.base_index_and_length( + access_range, i) + for i in range(access_range.dim(dim_type.set))])) + except StaticValueFindingError as e: + var_to_error[var_name] = str(e) + continue + + else: + if bad_subscripts: + raise RuntimeError("cannot determine access range for '%s': " + "undetermined index in subscript(s) '%s'" + % (var_name, ", ".join( + str(i) for i in bad_subscripts))) + + # no subscripts found, let's call it a scalar + base_indices = () + shape = () - # no subscripts found, let's call it a scalar - base_indices = () - shape = () + var_to_base_indices[var_name] = base_indices + var_to_shape[var_name] = shape - return base_indices, shape + return var_to_base_indices, var_to_shape, var_to_error def determine_shapes_of_temporaries(knl): new_temp_vars = knl.temporary_variables.copy() import loopy as lp - from loopy.diagnostic import StaticValueFindingError - new_temp_vars = {} + vars_needing_shape_inference = set() + for tv in six.itervalues(knl.temporary_variables): if tv.shape is lp.auto or tv.base_indices is lp.auto: - def feed_all_expressions(receiver): - for insn in knl.instructions: - insn.with_transformed_expressions( - lambda expr: receiver(expr, knl.insn_inames(insn))) + vars_needing_shape_inference.add(tv.name) - def feed_assignee_of_instruction(receiver): - for insn in knl.instructions: - for assignee in insn.assignees: - receiver(assignee, knl.insn_inames(insn)) + def feed_all_expressions(receiver): + for insn in knl.instructions: + insn.with_transformed_expressions( + lambda expr: receiver(expr, knl.insn_inames(insn))) - try: - base_indices, shape = find_var_shape( - knl, tv.name, feed_all_expressions) - except StaticValueFindingError as e: - warn_with_kernel(knl, "temp_shape_fallback", - "Had to fall back to legacy method of determining " - "shape of temporary '%s' because: %s" - % (tv.name, str(e))) + var_to_base_indices, var_to_shape, var_to_error = ( + find_shapes_of_vars( + knl, vars_needing_shape_inference, feed_all_expressions)) + + # {{{ fall back to legacy method + + if len(var_to_error) > 0: + vars_needing_shape_inference = set(var_to_error.keys()) + + from six import iteritems + for varname, err in iteritems(var_to_error): + warn_with_kernel(knl, "temp_shape_fallback", + "Had to fall back to legacy method of determining " + "shape of temporary '%s' because: %s" + % (varname, err)) + + def feed_assignee_of_instruction(receiver): + for insn in knl.instructions: + for assignee in insn.assignees: + receiver(assignee, knl.insn_inames(insn)) + + var_to_base_indices_fallback, var_to_shape_fallback, var_to_error = ( + find_shapes_of_vars( + knl, vars_needing_shape_inference, feed_assignee_of_instruction)) - base_indices, shape = find_var_shape( - knl, tv.name, feed_assignee_of_instruction) + if len(var_to_error) > 0: + # No way around errors: propagate an exception upward. + formatted_errors = ( + "\n\n".join("'%s': %s" % (varname, var_to_error[varname]) + for varname in sorted(var_to_error.keys()))) - if tv.base_indices is lp.auto: - tv = tv.copy(base_indices=base_indices) - if tv.shape is lp.auto: - tv = tv.copy(shape=shape) + raise LoopyError("got the following exception(s) trying to find the " + "shape of temporary variables: %s" % formatted_errors) + var_to_base_indices.update(var_to_base_indices_fallback) + var_to_shape.update(var_to_shape_fallback) + + # }}} + + new_temp_vars = {} + + for tv in six.itervalues(knl.temporary_variables): + if tv.base_indices is lp.auto: + tv = tv.copy(base_indices=var_to_base_indices[tv.name]) + if tv.shape is lp.auto: + tv = tv.copy(shape=var_to_shape[tv.name]) new_temp_vars[tv.name] = tv return knl.copy(temporary_variables=new_temp_vars) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 52fd6e57f92e7f9599a3a0fb4256f97347708303..b14fba5706b83c94a86b66079925939567d60594 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1471,12 +1471,13 @@ def get_access_range(domain, subscript, assumptions): # {{{ access range mapper -class AccessRangeMapper(WalkMapper): - def __init__(self, kernel, arg_name): +class BatchedAccessRangeMapper(WalkMapper): + + def __init__(self, kernel, arg_names): self.kernel = kernel - self.arg_name = arg_name - self.access_range = None - self.bad_subscripts = [] + self.arg_names = set(arg_names) + self.access_ranges = dict((arg, None) for arg in arg_names) + self.bad_subscripts = dict((arg, []) for arg in arg_names) def map_subscript(self, expr, inames): domain = self.kernel.get_inames_domain(inames) @@ -1484,38 +1485,58 @@ class AccessRangeMapper(WalkMapper): assert isinstance(expr.aggregate, p.Variable) - if expr.aggregate.name != self.arg_name: + if expr.aggregate.name not in self.arg_names: return + arg_name = expr.aggregate.name subscript = expr.index_tuple if not get_dependencies(subscript) <= set(domain.get_var_dict()): - self.bad_subscripts.append(expr) + self.bad_subscripts[arg_name].append(expr) return access_range = get_access_range(domain, subscript, self.kernel.assumptions) - if self.access_range is None: - self.access_range = access_range + if self.access_ranges[arg_name] is None: + self.access_ranges[arg_name] = access_range else: - if (self.access_range.dim(dim_type.set) + if (self.access_ranges[arg_name].dim(dim_type.set) != access_range.dim(dim_type.set)): raise RuntimeError( "error while determining shape of argument '%s': " "varying number of indices encountered" - % self.arg_name) + % arg_name) - self.access_range = self.access_range | access_range + self.access_ranges[arg_name] = ( + self.access_ranges[arg_name] | access_range) def map_linear_subscript(self, expr, inames): self.rec(expr.index, inames) - if expr.aggregate.name == self.arg_name: - self.bad_subscripts.append(expr) + if expr.aggregate.name in self.arg_names: + self.bad_subscripts[expr.aggregate.name].append(expr) def map_reduction(self, expr, inames): return WalkMapper.map_reduction(self, expr, inames | set(expr.inames)) + +class AccessRangeMapper(object): + + def __init__(self, kernel, arg_name): + self.arg_name = arg_name + self.inner_mapper = BatchedAccessRangeMapper(kernel, [arg_name]) + + def __call__(self, expr, inames): + return self.inner_mapper(expr, inames) + + @property + def access_range(self): + return self.inner_mapper.access_ranges[self.arg_name] + + @property + def bad_subscripts(self): + return self.inner_mapper.bad_subscripts[self.arg_name] + # }}}