diff --git a/loopy/array_buffer_map.py b/loopy/array_buffer_map.py index 6d1c56eef27e6e809dfc6b4048f7a570d1c29a10..9189421271ea28c95be861c785243aabbd7bd934 100644 --- a/loopy/array_buffer_map.py +++ b/loopy/array_buffer_map.py @@ -42,7 +42,6 @@ class AccessDescriptor(Record): __slots__ = [ "identifier", - "expands_footprint", "storage_axis_exprs", ] @@ -138,16 +137,15 @@ def build_global_storage_to_sweep_map(kernel, access_descriptors, # build footprint for accdesc in access_descriptors: - if accdesc.expands_footprint: - stor2sweep = build_per_access_storage_to_domain_map( - accdesc, domain_dup_sweep, - storage_axis_names, - prime_sweep_inames) + stor2sweep = build_per_access_storage_to_domain_map( + accdesc, domain_dup_sweep, + storage_axis_names, + prime_sweep_inames) - if global_stor2sweep is None: - global_stor2sweep = stor2sweep - else: - global_stor2sweep = global_stor2sweep.union(stor2sweep) + if global_stor2sweep is None: + global_stor2sweep = stor2sweep + else: + global_stor2sweep = global_stor2sweep.union(stor2sweep) if isinstance(global_stor2sweep, isl.BasicMap): global_stor2sweep = isl.Map.from_basic_map(global_stor2sweep) @@ -336,9 +334,6 @@ class ArrayToBufferMap(object): return convexify(domain) def is_access_descriptor_in_footprint(self, accdesc): - if accdesc.expands_footprint: - return True - # Make all inames except the sweep parameters. (The footprint may depend on # those.) (I.e. only leave sweep inames as out parameters.) diff --git a/loopy/buffer.py b/loopy/buffer.py index e8e15dbdf1a95a0d21007d4245521dd5a04b0a24..2a7386af876ecf7159dd272387bb968d2720edd8 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -74,7 +74,6 @@ class ArrayAccessReplacer(ExpandingIdentityMapper): def map_array_access(self, index, expn_state): accdesc = AccessDescriptor( identifier=None, - expands_footprint=False, storage_axis_exprs=index) if not self.array_base_map.is_access_descriptor_in_footprint(accdesc): @@ -182,7 +181,6 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, access_descriptors.append( AccessDescriptor( identifier=insn.id, - expands_footprint=True, storage_axis_exprs=index)) # {{{ find fetch/store inames diff --git a/loopy/precompute.py b/loopy/precompute.py index 15b46ecd648ab9289c4d2a467dd5b5cbaa6a4be4..ce6711f073eec20e13b93089b5fb309e751aee42 100644 --- a/loopy/precompute.py +++ b/loopy/precompute.py @@ -114,7 +114,6 @@ class RuleInvocationGatherer(ExpandingIdentityMapper): args = [arg_context[arg_name] for arg_name in rule.arguments] - # Do not set expands_footprint here, it is set below. self.access_descriptors.append( RuleAccessDescriptor( identifier=access_descriptor_id(args, expn_state.stack), @@ -152,42 +151,30 @@ class RuleInvocationReplacer(ExpandingIdentityMapper): self.target_var_name = target_var_name def map_substitution(self, name, tag, arguments, expn_state): - process_me = name == self.subst_name - - if self.subst_tag is not None and self.subst_tag != tag: - process_me = False - - process_me = process_me and self.within(expn_state.stack) + if not ( + name == self.subst_name + and self.within(expn_state.stack) + and (self.subst_tag is None or self.subst_tag == tag)): + return ExpandingIdentityMapper.map_substitution( + self, name, tag, arguments, expn_state) - # {{{ find matching invocation descriptor + # {{{ check if in footprint rule = self.old_subst_rules[name] arg_context = self.make_new_arg_context( name, rule.arguments, arguments, expn_state.arg_context) args = [arg_context[arg_name] for arg_name in rule.arguments] - if not process_me: - return ExpandingIdentityMapper.map_substitution( - self, name, tag, arguments, expn_state) - - matching_accdesc = None - for accdesc in self.access_descriptors: - if accdesc.identifier == access_descriptor_id(args, expn_state.stack): - # Could be more than one, that's fine. - matching_accdesc = accdesc - break - - assert matching_accdesc is not None - - accdesc = matching_accdesc - del matching_accdesc - - # }}} + accdesc = AccessDescriptor( + storage_axis_exprs=storage_axis_exprs( + self.storage_axis_sources, args)) if not self.array_base_map.is_access_descriptor_in_footprint(accdesc): return ExpandingIdentityMapper.map_substitution( self, name, tag, arguments, expn_state) + # }}} + assert len(arguments) == len(rule.arguments) abm = self.array_base_map @@ -374,7 +361,6 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, access_descriptors.append( RuleAccessDescriptor( identifier=access_descriptor_id(args, None), - expands_footprint=True, args=args )) @@ -382,19 +368,16 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # {{{ gather up invocations in kernel code, finish access_descriptors - invg = RuleInvocationGatherer(kernel, subst_name, subst_tag, within) - - import loopy as lp - for insn in kernel.instructions: - if isinstance(insn, lp.ExpressionInstruction): - invg(insn.expression, insn.id, insn.tags) + if not footprint_generators: + invg = RuleInvocationGatherer(kernel, subst_name, subst_tag, within) - for accdesc in invg.access_descriptors: - access_descriptors.append( - accdesc.copy(expands_footprint=footprint_generators is None)) + import loopy as lp + for insn in kernel.instructions: + if isinstance(insn, lp.ExpressionInstruction): + invg(insn.expression, insn.id, insn.tags) - if not access_descriptors: - raise RuntimeError("no invocations of '%s' found" % subst_name) + if not access_descriptors: + raise RuntimeError("no invocations of '%s' found" % subst_name) # }}} @@ -403,10 +386,9 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, expanding_usage_arg_deps = set() for accdesc in access_descriptors: - if accdesc.expands_footprint: - for arg in accdesc.args: - expanding_usage_arg_deps.update( - get_dependencies(arg) & kernel.all_inames()) + for arg in accdesc.args: + expanding_usage_arg_deps.update( + get_dependencies(arg) & kernel.all_inames()) # }}}