diff --git a/loopy/cse.py b/loopy/cse.py index 17f25660d596d2dbcede0aa26ab7ea5c47d5b810..8d8c7477baabe2b8f65c50d33330452c1ae0e592 100644 --- a/loopy/cse.py +++ b/loopy/cse.py @@ -23,8 +23,6 @@ THE SOFTWARE. """ - - import islpy as isl from islpy import dim_type from loopy.symbolic import (get_dependencies, SubstitutionMapper, @@ -36,8 +34,6 @@ from pytools import Record from pymbolic import var - - class InvocationDescriptor(Record): __slots__ = [ "args", @@ -50,8 +46,6 @@ class InvocationDescriptor(Record): ] - - def to_parameters_or_project_out(param_inames, set_inames, set): for iname in set.get_space().get_var_dict().keys(): if iname in param_inames: @@ -68,8 +62,6 @@ def to_parameters_or_project_out(param_inames, set_inames, set): return set - - # {{{ construct storage->sweep map def build_per_access_storage_to_domain_map(invdesc, domain, @@ -126,6 +118,7 @@ def build_per_access_storage_to_domain_map(invdesc, domain, # stor2sweep is back in map_space return stor2sweep + def move_to_par_from_out(s2smap, except_inames): while True: var_dict = s2smap.get_var_dict(dim_type.out) @@ -140,6 +133,7 @@ def move_to_par_from_out(s2smap, except_inames): else: return s2smap + def build_global_storage_to_sweep_map(kernel, invocation_descriptors, domain_dup_sweep, dup_sweep_index, storage_axis_names, storage_axis_sources, @@ -178,8 +172,8 @@ def build_global_storage_to_sweep_map(kernel, invocation_descriptors, # {{{ check if non-footprint-building invocation descriptors fall into footprint - # Make all inames except the sweep parameters. (The footprint may depend on those.) - # (I.e. only leave sweep inames as out parameters.) + # Make all inames except the sweep parameters. (The footprint may depend on + # those.) (I.e. only leave sweep inames as out parameters.) global_s2s_par_dom = move_to_par_from_out( global_stor2sweep, except_inames=frozenset(primed_sweep_inames)).domain() @@ -225,10 +219,12 @@ def build_global_storage_to_sweep_map(kernel, invocation_descriptors, arg_restrictions = ( aligned_g_s2s_parm_dom - .eliminate(dim_type.set, 0, aligned_g_s2s_parm_dom.dim(dim_type.set)) + .eliminate(dim_type.set, 0, + aligned_g_s2s_parm_dom.dim(dim_type.set)) .remove_divs()) - is_in_footprint = (arg_restrictions & s2s_domain).is_subset(aligned_g_s2s_parm_dom) + is_in_footprint = (arg_restrictions & s2s_domain).is_subset( + aligned_g_s2s_parm_dom) invdesc.is_in_footprint = is_in_footprint @@ -238,6 +234,7 @@ def build_global_storage_to_sweep_map(kernel, invocation_descriptors, # }}} + # {{{ compute storage bounds def find_var_base_indices_and_shape_from_inames( @@ -248,8 +245,6 @@ def find_var_base_indices_and_shape_from_inames( return zip(*base_indices_and_sizes) - - def compute_bounds(kernel, domain, subst_name, stor2sweep, primed_sweep_inames, storage_axis_names): @@ -271,9 +266,6 @@ def compute_bounds(kernel, domain, subst_name, stor2sweep, # }}} - - - def get_access_info(kernel, domain, subst_name, storage_axis_names, storage_axis_sources, sweep_inames, invocation_descriptors): @@ -291,7 +283,8 @@ def get_access_info(kernel, domain, subst_name, primed_sweep_inames) prime_sweep_inames = SubstitutionMapper(make_subst_func( - dict((sin, var(psin)) for sin, psin in zip(sweep_inames, primed_sweep_inames)))) + dict((sin, var(psin)) + for sin, psin in zip(sweep_inames, primed_sweep_inames)))) # }}} @@ -367,9 +360,6 @@ def get_access_info(kernel, domain, subst_name, storage_base_indices, non1_storage_base_indices, non1_storage_shape) - - - def simplify_via_aff(expr): from loopy.symbolic import aff_from_expr, aff_to_expr deps = get_dependencies(expr) @@ -378,8 +368,6 @@ def simplify_via_aff(expr): expr)) - - class InvocationGatherer(ExpandingIdentityMapper): def __init__(self, kernel, subst_name, subst_tag, within): ExpandingIdentityMapper.__init__(self, @@ -435,9 +423,7 @@ class InvocationGatherer(ExpandingIdentityMapper): args=[arg_context[arg_name] for arg_name in rule.arguments], expansion_stack=expn_state.stack)) - return 0 # exact value irrelevant - - + return 0 # exact value irrelevant class InvocationReplacer(ExpandingIdentityMapper): @@ -581,9 +567,12 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, within its arguments. A new, dedicated storage axis is allocated for such an axis. - :arg sweep_inames: A :class:`list` of inames and/or rule argument names to be swept. - :arg storage_axes: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes. - :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. + :arg sweep_inames: A :class:`list` of inames and/or rule argument + names to be swept. + :arg storage_axes: A :class:`list` of inames and/or rule argument + names/indices to be used as storage axes. + :arg within: a stack match as understood by + :func:`loopy.context_matching.parse_stack_match`. If `storage_axes` is not specified, it defaults to the arrangement `<direct sweep axes><arguments>` with the direct sweep axes being the @@ -667,8 +656,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, invocation_descriptors.append( InvocationDescriptor(args=args, expands_footprint=True, - expansion_stack=None, - )) + expansion_stack=None)) # }}} @@ -725,7 +713,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, submap(subst.expression, insn_id=None)) & kernel.all_inames() if value_inames - expanding_usage_arg_deps < extra_storage_axes: raise RuntimeError("unreferenced sweep inames specified: " - + ", ".join(extra_storage_axes - value_inames - expanding_usage_arg_deps)) + + ", ".join(extra_storage_axes + - value_inames - expanding_usage_arg_deps)) new_iname_to_tag = {} @@ -737,7 +726,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, expr_subst_dict = {} storage_axis_names = [] - storage_axis_sources = [] # number for arg#, or iname + storage_axis_sources = [] # number for arg#, or iname for i, saxis in enumerate(storage_axes): tag_lookup_saxis = saxis @@ -882,7 +871,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # }}} - kernel = kernel.copy( + kernel = kernel.copy( domains=domch.get_domains_with(new_domain), instructions=[compute_insn] + kernel.instructions, temporary_variables=new_temporary_variables)