diff --git a/loopy/array_buffer_map.py b/loopy/array_buffer_map.py index 0be16ce5b3d9b7af477f7ee849d0dbe8e6b90c7c..72fca8a4a17a183251c877e41b033539e2f00c85 100644 --- a/loopy/array_buffer_map.py +++ b/loopy/array_buffer_map.py @@ -28,7 +28,7 @@ from islpy import dim_type from loopy.symbolic import (get_dependencies, SubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func -from pytools import Record +from pytools import Record, memoize_method from pymbolic import var @@ -64,7 +64,7 @@ def to_parameters_or_project_out(param_inames, set_inames, set): # {{{ construct storage->sweep map -def build_per_access_storage_to_domain_map(accdesc, domain, +def build_per_access_storage_to_domain_map(storage_axis_exprs, domain, storage_axis_names, prime_sweep_inames): @@ -91,7 +91,7 @@ def build_per_access_storage_to_domain_map(accdesc, domain, from loopy.symbolic import aff_from_expr - for saxis, sa_expr in zip(storage_axis_names, accdesc.storage_axis_exprs): + for saxis, sa_expr in zip(storage_axis_names, storage_axis_exprs): cns = isl.Constraint.equality_from_aff( aff_from_expr(set_space, var(saxis+"'") - prime_sweep_inames(sa_expr))) @@ -138,7 +138,7 @@ def build_global_storage_to_sweep_map(kernel, access_descriptors, # build footprint for accdesc in access_descriptors: stor2sweep = build_per_access_storage_to_domain_map( - accdesc, domain_dup_sweep, + accdesc.storage_axis_exprs, domain_dup_sweep, storage_axis_names, prime_sweep_inames) @@ -336,6 +336,11 @@ class ArrayToBufferMap(object): return convexify(domain) def is_access_descriptor_in_footprint(self, accdesc): + return self._is_access_descriptor_in_footprint_inner( + tuple(accdesc.storage_axis_exprs)) + + @memoize_method + def _is_access_descriptor_in_footprint_inner(self, storage_axis_exprs): # Make all inames except the sweep parameters. (The footprint may depend on # those.) (I.e. only leave sweep inames as out parameters.) @@ -347,7 +352,7 @@ class ArrayToBufferMap(object): set(global_s2s_par_dom.get_var_names(dim_type.param)) & self.kernel.all_inames()) - for arg in accdesc.storage_axis_exprs: + for arg in storage_axis_exprs: arg_inames.update(get_dependencies(arg)) arg_inames = frozenset(arg_inames) @@ -363,7 +368,8 @@ class ArrayToBufferMap(object): usage_domain = usage_domain.set_dim_name( dim_type.set, i, iname+"'") - stor2sweep = build_per_access_storage_to_domain_map(accdesc, + stor2sweep = build_per_access_storage_to_domain_map( + storage_axis_exprs, usage_domain, self.storage_axis_names, self.prime_sweep_inames)