From f5ed460dcb98b5cb18848725175f4191df7c12d7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 13 Apr 2015 22:45:53 -0500 Subject: [PATCH] Factor array_buffer out of precompute --- loopy/array_buffer.py | 408 ++++++++++++++++++++++++++++++++++++ loopy/precompute.py | 476 ++++++++---------------------------------- 2 files changed, 496 insertions(+), 388 deletions(-) create mode 100644 loopy/array_buffer.py diff --git a/loopy/array_buffer.py b/loopy/array_buffer.py new file mode 100644 index 000000000..c15935b73 --- /dev/null +++ b/loopy/array_buffer.py @@ -0,0 +1,408 @@ +from __future__ import division, absolute_import +from six.moves import range, zip + +__copyright__ = "Copyright (C) 2012-2015 Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import islpy as isl +from islpy import dim_type +from loopy.symbolic import (get_dependencies, SubstitutionMapper) +from pymbolic.mapper.substitutor import make_subst_func + +from pytools import Record +from pymbolic import var + + +class AccessDescriptor(Record): + """ + .. attribute:: identifier + + An identifier under user control, used to connect this access descriptor + to the access that generated it. Any Python value. + """ + + __slots__ = [ + "identifier", + "expands_footprint", + "storage_axis_exprs", + ] + + +def to_parameters_or_project_out(param_inames, set_inames, set): + for iname in list(set.get_space().get_var_dict().keys()): + if iname in param_inames: + dt, idx = set.get_space().get_var_dict()[iname] + set = set.move_dims( + dim_type.param, set.dim(dim_type.param), + dt, idx, 1) + elif iname in set_inames: + pass + else: + dt, idx = set.get_space().get_var_dict()[iname] + set = set.project_out(dt, idx, 1) + + return set + + +# {{{ construct storage->sweep map + +def build_per_access_storage_to_domain_map(accdesc, domain, + storage_axis_names, + prime_sweep_inames): + + map_space = domain.space + stor_dim = len(storage_axis_names) + rn = map_space.dim(dim_type.out) + + map_space = map_space.add_dims(dim_type.in_, stor_dim) + for i, saxis in enumerate(storage_axis_names): + # arg names are initially primed, to be replaced with unprimed + # base-0 versions below + + map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'") + + # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn) + + set_space = map_space.move_dims( + dim_type.out, rn, + dim_type.in_, 0, stor_dim).range() + + # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes'] + + stor2sweep = None + + from loopy.symbolic import aff_from_expr + + for saxis, sa_expr in zip(storage_axis_names, accdesc.storage_axis_exprs): + cns = isl.Constraint.equality_from_aff( + aff_from_expr(set_space, + var(saxis+"'") - prime_sweep_inames(sa_expr))) + + cns_map = isl.BasicMap.from_constraint(cns) + if stor2sweep is None: + stor2sweep = cns_map + else: + stor2sweep = stor2sweep.intersect(cns_map) + + if stor2sweep is not None: + stor2sweep = stor2sweep.move_dims( + dim_type.in_, 0, + dim_type.out, rn, stor_dim) + + # stor2sweep is back in map_space + return stor2sweep + + +def move_to_par_from_out(s2smap, except_inames): + while True: + var_dict = s2smap.get_var_dict(dim_type.out) + todo_inames = set(var_dict) - except_inames + if todo_inames: + iname = todo_inames.pop() + + _, dim_idx = var_dict[iname] + s2smap = s2smap.move_dims( + dim_type.param, s2smap.dim(dim_type.param), + dim_type.out, dim_idx, 1) + else: + return s2smap + + +def build_global_storage_to_sweep_map(kernel, access_descriptors, + domain_dup_sweep, dup_sweep_index, + storage_axis_names, + sweep_inames, primed_sweep_inames, prime_sweep_inames): + # The storage map goes from storage axes to the domain. + # The first len(arg_names) storage dimensions are the rule's arguments. + + global_stor2sweep = None + + # build footprint + for accdesc in access_descriptors: + if accdesc.expands_footprint: + stor2sweep = build_per_access_storage_to_domain_map( + accdesc, domain_dup_sweep, + storage_axis_names, + prime_sweep_inames) + + if global_stor2sweep is None: + global_stor2sweep = stor2sweep + else: + global_stor2sweep = global_stor2sweep.union(stor2sweep) + + if isinstance(global_stor2sweep, isl.BasicMap): + global_stor2sweep = isl.Map.from_basic_map(global_stor2sweep) + global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep) + + # space for global_stor2sweep: + # [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn) + + return global_stor2sweep + +# }}} + + +# {{{ compute storage bounds + +def find_var_base_indices_and_shape_from_inames( + domain, inames, cache_manager, context=None): + base_indices_and_sizes = [ + cache_manager.base_index_and_length(domain, iname, context) + for iname in inames] + return list(zip(*base_indices_and_sizes)) + + +def compute_bounds(kernel, domain, stor2sweep, + primed_sweep_inames, storage_axis_names): + + bounds_footprint_map = move_to_par_from_out( + stor2sweep, except_inames=frozenset(primed_sweep_inames)) + + # compute bounds for each storage axis + storage_domain = bounds_footprint_map.domain().coalesce() + + if not storage_domain.is_bounded(): + raise RuntimeError("sweep did not result in a bounded storage domain") + + return find_var_base_indices_and_shape_from_inames( + storage_domain, [saxis+"'" for saxis in storage_axis_names], + kernel.cache_manager, context=kernel.assumptions) + +# }}} + + +# {{{ array-to-buffer map + +class ArrayToBufferMap(object): + def __init__(self, kernel, domain, sweep_inames, access_descriptors, + storage_axis_count): + self.kernel = kernel + self.sweep_inames = sweep_inames + + storage_axis_names = self.storage_axis_names = [ + "_loopy_storage_%d" % i for i in range(storage_axis_count)] + + # {{{ duplicate sweep inames + + # The duplication is necessary, otherwise the storage fetch + # inames remain weirdly tied to the original sweep inames. + + self.primed_sweep_inames = [psin+"'" for psin in sweep_inames] + + from loopy.isl_helpers import duplicate_axes + dup_sweep_index = domain.space.dim(dim_type.out) + domain_dup_sweep = duplicate_axes( + domain, sweep_inames, + self.primed_sweep_inames) + + self.prime_sweep_inames = SubstitutionMapper(make_subst_func( + dict((sin, var(psin)) + for sin, psin in zip(sweep_inames, self.primed_sweep_inames)))) + + # # }}} + + self.stor2sweep = build_global_storage_to_sweep_map( + kernel, access_descriptors, + domain_dup_sweep, dup_sweep_index, + storage_axis_names, + sweep_inames, self.primed_sweep_inames, self.prime_sweep_inames) + + storage_base_indices, storage_shape = compute_bounds( + kernel, domain, self.stor2sweep, self.primed_sweep_inames, + storage_axis_names) + + # compute augmented domain + + # {{{ filter out unit-length dimensions + + non1_storage_axis_flags = [] + non1_storage_shape = [] + + for saxis, bi, l in zip( + storage_axis_names, storage_base_indices, storage_shape): + has_length_non1 = l != 1 + + non1_storage_axis_flags.append(has_length_non1) + + if has_length_non1: + non1_storage_shape.append(l) + + # }}} + + # {{{ subtract off the base indices + # add the new, base-0 indices as new in dimensions + + sp = self.stor2sweep.get_space() + stor_idx = sp.dim(dim_type.out) + + n_stor = storage_axis_count + nn1_stor = len(non1_storage_shape) + + aug_domain = self.stor2sweep.move_dims( + dim_type.out, stor_idx, + dim_type.in_, 0, + n_stor).range() + + # aug_domain space now: + # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'] + + aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor) + + inew = 0 + for i, name in enumerate(storage_axis_names): + if non1_storage_axis_flags[i]: + aug_domain = aug_domain.set_dim_name( + dim_type.set, stor_idx + inew, name) + inew += 1 + + # aug_domain space now: + # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'][n1_stor_axes] + + from loopy.symbolic import aff_from_expr + for saxis, bi, s in zip(storage_axis_names, storage_base_indices, + storage_shape): + if s != 1: + cns = isl.Constraint.equality_from_aff( + aff_from_expr(aug_domain.get_space(), + var(saxis) - (var(saxis+"'") - bi))) + + aug_domain = aug_domain.add_constraint(cns) + + # }}} + + # eliminate (primed) storage axes with non-zero base indices + aug_domain = aug_domain.project_out(dim_type.set, stor_idx+nn1_stor, n_stor) + + # eliminate duplicated sweep_inames + nsweep = len(sweep_inames) + aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index, nsweep) + + self.non1_storage_axis_flags = non1_storage_axis_flags + self.aug_domain = aug_domain + self.storage_base_indices = storage_base_indices + self.non1_storage_shape = non1_storage_shape + + def augment_domain_with_sweep(self, domain, new_non1_storage_axis_names, + boxify_sweep=False): + + renamed_aug_domain = self.aug_domain + first_storage_index = ( + renamed_aug_domain.dim(dim_type.set) + - len(self.non1_storage_shape)) + + inon1 = 0 + for i, old_name in enumerate(self.storage_axis_names): + if not self.non1_storage_axis_flags[i]: + continue + + new_name = new_non1_storage_axis_names[inon1] + + assert ( + renamed_aug_domain.get_dim_name( + dim_type.set, first_storage_index+inon1) + == old_name) + renamed_aug_domain = renamed_aug_domain.set_dim_name( + dim_type.set, first_storage_index+inon1, new_name) + + inon1 += 1 + + domain, renamed_aug_domain = isl.align_two(domain, renamed_aug_domain) + domain = domain & renamed_aug_domain + + from loopy.isl_helpers import convexify, boxify + if boxify_sweep: + return boxify(self.kernel.cache_manager, domain, + new_non1_storage_axis_names, self.kernel.assumptions) + else: + return convexify(domain) + + def is_access_descriptor_in_footprint(self, accdesc): + if accdesc.expands_footprint: + return True + + # Make all inames except the sweep parameters. (The footprint may depend on + # those.) (I.e. only leave sweep inames as out parameters.) + + global_s2s_par_dom = move_to_par_from_out( + self.stor2sweep, + except_inames=frozenset(self.primed_sweep_inames)).domain() + + arg_inames = ( + set(global_s2s_par_dom.get_var_names(dim_type.param)) + & self.kernel.all_inames()) + + for arg in accdesc.args: + arg_inames.update(get_dependencies(arg)) + arg_inames = frozenset(arg_inames) + + from loopy.kernel import CannotBranchDomainTree + try: + usage_domain = self.kernel.get_inames_domain(arg_inames) + except CannotBranchDomainTree: + return False + + for i in range(usage_domain.dim(dim_type.set)): + iname = usage_domain.get_dim_name(dim_type.set, i) + if iname in self.sweep_inames: + usage_domain = usage_domain.set_dim_name( + dim_type.set, i, iname+"'") + + stor2sweep = build_per_access_storage_to_domain_map(accdesc, + usage_domain, self.storage_axis_names, + self.prime_sweep_inames) + + if isinstance(stor2sweep, isl.BasicMap): + stor2sweep = isl.Map.from_basic_map(stor2sweep) + + stor2sweep = stor2sweep.intersect_range(usage_domain) + + stor2sweep = move_to_par_from_out(stor2sweep, + except_inames=frozenset(self.primed_sweep_inames)) + + s2s_domain = stor2sweep.domain() + s2s_domain, aligned_g_s2s_parm_dom = isl.align_two( + s2s_domain, global_s2s_par_dom) + + arg_restrictions = ( + aligned_g_s2s_parm_dom + .eliminate(dim_type.set, 0, + aligned_g_s2s_parm_dom.dim(dim_type.set)) + .remove_divs()) + + return (arg_restrictions & s2s_domain).is_subset( + aligned_g_s2s_parm_dom) + + +class NoOpArrayToBufferMap(object): + non1_storage_axis_names = () + storage_base_indices = () + non1_storage_shape = () + + def is_access_descriptor_in_footprint(self, accdesc): + # no index dependencies--every reference to the subst rule + # is necessarily in the footprint. + + return True + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/precompute.py b/loopy/precompute.py index 8bd049b12..de5cd9e02 100644 --- a/loopy/precompute.py +++ b/loopy/precompute.py @@ -1,8 +1,6 @@ -from __future__ import division -from __future__ import absolute_import +from __future__ import division, absolute_import import six -from six.moves import range -from six.moves import zip +from six.moves import range, zip __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -28,337 +26,35 @@ THE SOFTWARE. import islpy as isl -from islpy import dim_type from loopy.symbolic import (get_dependencies, SubstitutionMapper, ExpandingIdentityMapper) from pymbolic.mapper.substitutor import make_subst_func import numpy as np -from pytools import Record from pymbolic import var +from loopy.array_buffer import (ArrayToBufferMap, NoOpArrayToBufferMap, + AccessDescriptor) -class InvocationDescriptor(Record): - __slots__ = [ - "args", - "expands_footprint", - "is_in_footprint", - # Remember where the invocation happened, in terms of the expansion - # call stack. - "expansion_stack", - ] +class RuleAccessDescriptor(AccessDescriptor): + __slots__ = ["args", "expansion_stack"] -def to_parameters_or_project_out(param_inames, set_inames, set): - for iname in list(set.get_space().get_var_dict().keys()): - if iname in param_inames: - dt, idx = set.get_space().get_var_dict()[iname] - set = set.move_dims( - dim_type.param, set.dim(dim_type.param), - dt, idx, 1) - elif iname in set_inames: - pass - else: - dt, idx = set.get_space().get_var_dict()[iname] - set = set.project_out(dt, idx, 1) - - return set - - -# {{{ construct storage->sweep map - -def build_per_access_storage_to_domain_map(invdesc, domain, - storage_axis_names, storage_axis_sources, - prime_sweep_inames): - - map_space = domain.get_space() - stor_dim = len(storage_axis_names) - rn = map_space.dim(dim_type.out) - - map_space = map_space.add_dims(dim_type.in_, stor_dim) - for i, saxis in enumerate(storage_axis_names): - # arg names are initially primed, to be replaced with unprimed - # base-0 versions below - - map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'") +def access_descriptor_id(args, expansion_stack): + return (args, expansion_stack) - # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn) - set_space = map_space.move_dims( - dim_type.out, rn, - dim_type.in_, 0, stor_dim).range() +def storage_axis_exprs(storage_axis_sources, args): + result = [] - # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes'] - - stor2sweep = None - - from loopy.symbolic import aff_from_expr - - for saxis, saxis_source in zip(storage_axis_names, storage_axis_sources): + for saxis_source in storage_axis_sources: if isinstance(saxis_source, int): - # an argument - cns = isl.Constraint.equality_from_aff( - aff_from_expr(set_space, - var(saxis+"'") - - prime_sweep_inames(invdesc.args[saxis_source]))) - else: - # a 'bare' sweep iname - cns = isl.Constraint.equality_from_aff( - aff_from_expr(set_space, - var(saxis+"'") - - prime_sweep_inames(var(saxis_source)))) - - cns_map = isl.BasicMap.from_constraint(cns) - if stor2sweep is None: - stor2sweep = cns_map + result.append(args[saxis_source]) else: - stor2sweep = stor2sweep.intersect(cns_map) - - if stor2sweep is not None: - stor2sweep = stor2sweep.move_dims( - dim_type.in_, 0, - dim_type.out, rn, stor_dim) - - # stor2sweep is back in map_space - return stor2sweep - - -def move_to_par_from_out(s2smap, except_inames): - while True: - var_dict = s2smap.get_var_dict(dim_type.out) - todo_inames = set(var_dict) - except_inames - if todo_inames: - iname = todo_inames.pop() - - _, dim_idx = var_dict[iname] - s2smap = s2smap.move_dims( - dim_type.param, s2smap.dim(dim_type.param), - dim_type.out, dim_idx, 1) - else: - return s2smap - - -def build_global_storage_to_sweep_map(kernel, invocation_descriptors, - domain_dup_sweep, dup_sweep_index, - storage_axis_names, storage_axis_sources, - sweep_inames, primed_sweep_inames, prime_sweep_inames): - """ - As a side effect, this fills out is_in_footprint in the - invocation descriptors. - """ - - # The storage map goes from storage axes to the domain. - # The first len(arg_names) storage dimensions are the rule's arguments. - - global_stor2sweep = None - - # build footprint - for invdesc in invocation_descriptors: - if invdesc.expands_footprint: - stor2sweep = build_per_access_storage_to_domain_map( - invdesc, domain_dup_sweep, - storage_axis_names, storage_axis_sources, - prime_sweep_inames) - - if global_stor2sweep is None: - global_stor2sweep = stor2sweep - else: - global_stor2sweep = global_stor2sweep.union(stor2sweep) - - invdesc.is_in_footprint = True - - if isinstance(global_stor2sweep, isl.BasicMap): - global_stor2sweep = isl.Map.from_basic_map(global_stor2sweep) - global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep) - - # space for global_stor2sweep: - # [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn) - - # {{{ check if non-footprint-building invocation descriptors fall into footprint - - # Make all inames except the sweep parameters. (The footprint may depend on - # those.) (I.e. only leave sweep inames as out parameters.) - global_s2s_par_dom = move_to_par_from_out( - global_stor2sweep, except_inames=frozenset(primed_sweep_inames)).domain() - - for invdesc in invocation_descriptors: - if not invdesc.expands_footprint: - arg_inames = ( - set(global_s2s_par_dom.get_var_names(dim_type.param)) - & kernel.all_inames()) - - for arg in invdesc.args: - arg_inames.update(get_dependencies(arg)) - arg_inames = frozenset(arg_inames) - - from loopy.kernel import CannotBranchDomainTree - try: - usage_domain = kernel.get_inames_domain(arg_inames) - except CannotBranchDomainTree: - # and that's the end of that. - invdesc.is_in_footprint = False - continue - - for i in range(usage_domain.dim(dim_type.set)): - iname = usage_domain.get_dim_name(dim_type.set, i) - if iname in sweep_inames: - usage_domain = usage_domain.set_dim_name( - dim_type.set, i, iname+"'") - - stor2sweep = build_per_access_storage_to_domain_map(invdesc, - usage_domain, storage_axis_names, storage_axis_sources, - prime_sweep_inames) - - if isinstance(stor2sweep, isl.BasicMap): - stor2sweep = isl.Map.from_basic_map(stor2sweep) - - stor2sweep = stor2sweep.intersect_range(usage_domain) - - stor2sweep = move_to_par_from_out(stor2sweep, - except_inames=frozenset(primed_sweep_inames)) - - s2s_domain = stor2sweep.domain() - s2s_domain, aligned_g_s2s_parm_dom = isl.align_two( - s2s_domain, global_s2s_par_dom) - - arg_restrictions = ( - aligned_g_s2s_parm_dom - .eliminate(dim_type.set, 0, - aligned_g_s2s_parm_dom.dim(dim_type.set)) - .remove_divs()) - - is_in_footprint = (arg_restrictions & s2s_domain).is_subset( - aligned_g_s2s_parm_dom) - - invdesc.is_in_footprint = is_in_footprint - - # }}} - - return global_stor2sweep - -# }}} - - -# {{{ compute storage bounds - -def find_var_base_indices_and_shape_from_inames( - domain, inames, cache_manager, context=None): - base_indices_and_sizes = [ - cache_manager.base_index_and_length(domain, iname, context) - for iname in inames] - return list(zip(*base_indices_and_sizes)) - - -def compute_bounds(kernel, domain, stor2sweep, - primed_sweep_inames, storage_axis_names): - - bounds_footprint_map = move_to_par_from_out( - stor2sweep, except_inames=frozenset(primed_sweep_inames)) - - # compute bounds for each storage axis - storage_domain = bounds_footprint_map.domain().coalesce() - - if not storage_domain.is_bounded(): - raise RuntimeError("sweep did not result in a bounded storage domain") - - return find_var_base_indices_and_shape_from_inames( - storage_domain, [saxis+"'" for saxis in storage_axis_names], - kernel.cache_manager, context=kernel.assumptions) + result.append(var(saxis_source)) -# }}} - - -def get_access_info(kernel, domain, - storage_axis_names, storage_axis_sources, - sweep_inames, invocation_descriptors): - - # {{{ duplicate sweep inames - - # The duplication is necessary, otherwise the storage fetch - # inames remain weirdly tied to the original sweep inames. - - primed_sweep_inames = [psin+"'" for psin in sweep_inames] - from loopy.isl_helpers import duplicate_axes - dup_sweep_index = domain.space.dim(dim_type.out) - domain_dup_sweep = duplicate_axes( - domain, sweep_inames, - primed_sweep_inames) - - prime_sweep_inames = SubstitutionMapper(make_subst_func( - dict((sin, var(psin)) - for sin, psin in zip(sweep_inames, primed_sweep_inames)))) - - # }}} - - stor2sweep = build_global_storage_to_sweep_map( - kernel, invocation_descriptors, - domain_dup_sweep, dup_sweep_index, - storage_axis_names, storage_axis_sources, - sweep_inames, primed_sweep_inames, prime_sweep_inames) - - storage_base_indices, storage_shape = compute_bounds( - kernel, domain, stor2sweep, primed_sweep_inames, - storage_axis_names) - - # compute augmented domain - - # {{{ filter out unit-length dimensions - - non1_storage_axis_names = [] - non1_storage_shape = [] - - for saxis, bi, l in zip(storage_axis_names, storage_base_indices, storage_shape): - if l != 1: - non1_storage_axis_names.append(saxis) - non1_storage_shape.append(l) - - # }}} - - # {{{ subtract off the base indices - # add the new, base-0 indices as new in dimensions - - sp = stor2sweep.get_space() - stor_idx = sp.dim(dim_type.out) - - n_stor = len(storage_axis_names) - nn1_stor = len(non1_storage_axis_names) - - aug_domain = stor2sweep.move_dims( - dim_type.out, stor_idx, - dim_type.in_, 0, - n_stor).range() - - # aug_domain space now: - # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'] - - aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor) - for i, name in enumerate(non1_storage_axis_names): - aug_domain = aug_domain.set_dim_name(dim_type.set, stor_idx+i, name) - - # aug_domain space now: - # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'][n1_stor_axes] - - from loopy.symbolic import aff_from_expr - for saxis, bi, s in zip(storage_axis_names, storage_base_indices, storage_shape): - if s != 1: - cns = isl.Constraint.equality_from_aff( - aff_from_expr(aug_domain.get_space(), - var(saxis) - (var(saxis+"'") - bi))) - - aug_domain = aug_domain.add_constraint(cns) - - # }}} - - # eliminate (primed) storage axes with non-zero base indices - aug_domain = aug_domain.project_out(dim_type.set, stor_idx+nn1_stor, n_stor) - - # eliminate duplicated sweep_inames - nsweep = len(sweep_inames) - aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index, nsweep) - - return (non1_storage_axis_names, aug_domain, - storage_base_indices, non1_storage_shape) + return result def simplify_via_aff(expr): @@ -369,7 +65,7 @@ def simplify_via_aff(expr): expr)) -class InvocationGatherer(ExpandingIdentityMapper): +class RuleInvocationGatherer(ExpandingIdentityMapper): def __init__(self, kernel, subst_name, subst_tag, within): ExpandingIdentityMapper.__init__(self, kernel.substitutions, kernel.get_var_name_generator()) @@ -383,7 +79,7 @@ class InvocationGatherer(ExpandingIdentityMapper): self.subst_tag = subst_tag self.within = within - self.invocation_descriptors = [] + self.access_descriptors = [] def map_substitution(self, name, tag, arguments, expn_state): process_me = name == self.subst_name @@ -423,17 +119,21 @@ class InvocationGatherer(ExpandingIdentityMapper): return ExpandingIdentityMapper.map_substitution( self, name, tag, arguments, expn_state) - self.invocation_descriptors.append( - InvocationDescriptor( - args=[arg_context[arg_name] for arg_name in rule.arguments], - expansion_stack=expn_state.stack)) + args = [arg_context[arg_name] for arg_name in rule.arguments] + + # Do not set expands_footprint here, it is set below. + self.access_descriptors.append( + RuleAccessDescriptor( + identifier=access_descriptor_id(args, expn_state.stack), + args=args, + )) return 0 # exact value irrelevant -class InvocationReplacer(ExpandingIdentityMapper): +class RuleInvocationReplacer(ExpandingIdentityMapper): def __init__(self, kernel, subst_name, subst_tag, within, - invocation_descriptors, + access_descriptors, array_base_map, storage_axis_names, storage_axis_sources, storage_base_indices, non1_storage_axis_names, target_var_name): @@ -449,7 +149,8 @@ class InvocationReplacer(ExpandingIdentityMapper): self.subst_tag = subst_tag self.within = within - self.invocation_descriptors = invocation_descriptors + self.access_descriptors = access_descriptors + self.array_base_map = array_base_map self.storage_axis_names = storage_axis_names self.storage_axis_sources = storage_axis_sources @@ -477,21 +178,21 @@ class InvocationReplacer(ExpandingIdentityMapper): return ExpandingIdentityMapper.map_substitution( self, name, tag, arguments, expn_state) - matching_invdesc = None - for invdesc in self.invocation_descriptors: - if invdesc.args == args and expn_state.stack: + matching_accdesc = None + for accdesc in self.access_descriptors: + if accdesc.identifier == access_descriptor_id(args, expn_state.stack): # Could be more than one, that's fine. - matching_invdesc = invdesc + matching_accdesc = accdesc break - assert matching_invdesc is not None + assert matching_accdesc is not None - invdesc = matching_invdesc - del matching_invdesc + accdesc = matching_accdesc + del matching_accdesc # }}} - if not invdesc.is_in_footprint: + if not self.array_base_map.is_access_descriptor_in_footprint(accdesc): return ExpandingIdentityMapper.map_substitution( self, name, tag, arguments, expn_state) @@ -646,11 +347,17 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, from loopy.context_matching import parse_stack_match within = parse_stack_match(within) + from loopy.kernel.data import parse_tag + default_tag = parse_tag(default_tag) + + subst = kernel.substitutions[subst_name] + c_subst_name = subst_name.replace(".", "_") + # }}} - # {{{ process invocations in footprint generators, start invocation_descriptors + # {{{ process invocations in footprint generators, start access_descriptors - invocation_descriptors = [] + access_descriptors = [] if footprint_generators: from pymbolic.primitives import Variable, Call @@ -664,35 +371,29 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, raise ValueError("footprint generator must " "be substitution rule invocation") - invocation_descriptors.append( - InvocationDescriptor(args=args, + access_descriptors.append( + RuleAccessDescriptor( + identifier=access_descriptor_id(args, None), expands_footprint=True, - expansion_stack=None)) + args=args + )) # }}} - c_subst_name = subst_name.replace(".", "_") + # {{{ gather up invocations in kernel code, finish access_descriptors - from loopy.kernel.data import parse_tag - default_tag = parse_tag(default_tag) - - subst = kernel.substitutions[subst_name] - arg_names = subst.arguments - - # {{{ gather up invocations in kernel code, finish invocation_descriptors - - invg = InvocationGatherer(kernel, subst_name, subst_tag, within) + invg = RuleInvocationGatherer(kernel, subst_name, subst_tag, within) import loopy as lp for insn in kernel.instructions: if isinstance(insn, lp.ExpressionInstruction): invg(insn.expression, insn.id, insn.tags) - for invdesc in invg.invocation_descriptors: - invocation_descriptors.append( - invdesc.copy(expands_footprint=footprint_generators is None)) + for accdesc in invg.access_descriptors: + access_descriptors.append( + accdesc.copy(expands_footprint=footprint_generators is None)) - if not invocation_descriptors: + if not access_descriptors: raise RuntimeError("no invocations of '%s' found" % subst_name) # }}} @@ -704,9 +405,9 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, expanding_usage_arg_deps = set() - for invdesc in invocation_descriptors: - if invdesc.expands_footprint: - for arg in invdesc.args: + for accdesc in access_descriptors: + if accdesc.expands_footprint: + for arg in accdesc.args: expanding_usage_arg_deps.update( get_dependencies(arg) & kernel.all_inames()) @@ -735,7 +436,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, if storage_axes is None: storage_axes = ( list(extra_storage_axes) - + list(range(len(arg_names)))) + + list(range(len(subst.arguments)))) expr_subst_dict = {} @@ -784,6 +485,16 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # }}} + # {{{ fill out access_descriptors[...].storage_axis_exprs + + access_descriptors = [ + accdesc.copy( + storage_axis_exprs=storage_axis_exprs( + storage_axis_sources, accdesc.args)) + for accdesc in access_descriptors] + + # }}} + expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps) assert expanding_inames <= kernel.all_inames() @@ -805,37 +516,26 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # }}} - (non1_storage_axis_names, new_domain, - storage_base_indices, non1_storage_shape) = \ - get_access_info(kernel, domch.domain, - storage_axis_names, storage_axis_sources, - sweep_inames, invocation_descriptors) + abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames, + access_descriptors, len(storage_axis_names)) - from loopy.isl_helpers import convexify, boxify - if fetch_bounding_box: - new_domain = boxify(kernel.cache_manager, new_domain, - non1_storage_axis_names, kernel.assumptions) - else: - new_domain = convexify(new_domain) - - for saxis in storage_axis_names: - if saxis not in non1_storage_axis_names: + non1_storage_axis_names = [] + for i, saxis in enumerate(storage_axis_names): + if abm.non1_storage_axis_flags[i]: + non1_storage_axis_names.append(saxis) + else: del new_iname_to_tag[saxis] - new_kernel_domains = domch.get_domains_with(new_domain) + new_kernel_domains = domch.get_domains_with( + abm.augment_domain_with_sweep( + domch.domain, non1_storage_axis_names, + boxify_sweep=fetch_bounding_box)) + else: # leave kernel domains unchanged new_kernel_domains = kernel.domains - non1_storage_axis_names = () - storage_base_indices = () - non1_storage_shape = () - - # no index dependencies--every reference to the subst rule - # is necessarily in the footprint. - - for invdesc in invocation_descriptors: - invdesc.is_in_footprint = True + abm = NoOpArrayToBufferMap() # {{{ set up compute insn @@ -856,7 +556,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, compute_expr = (SubstitutionMapper( make_subst_func(dict( (arg_name, zero_length_1_arg(arg_name)+bi) - for arg_name, bi in zip(storage_axis_names, storage_base_indices) + for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices) ))) (compute_expr)) @@ -870,10 +570,10 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # {{{ substitute rule into expressions in kernel (if within footprint) - invr = InvocationReplacer(kernel, subst_name, subst_tag, within, - invocation_descriptors, + invr = RuleInvocationReplacer(kernel, subst_name, subst_tag, within, + access_descriptors, abm, storage_axis_names, storage_axis_sources, - storage_base_indices, non1_storage_axis_names, + abm.storage_base_indices, non1_storage_axis_names, target_var_name) kernel = invr.map_kernel(kernel) @@ -897,8 +597,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, temp_var = lp.TemporaryVariable( name=target_var_name, dtype=dtype, - base_indices=(0,)*len(non1_storage_shape), - shape=tuple(non1_storage_shape), + base_indices=(0,)*len(abm.non1_storage_shape), + shape=tuple(abm.non1_storage_shape), is_local=temporary_is_local) new_temporary_variables[target_var_name] = temp_var -- GitLab