From bfa4f42728f40d8e62bf0c835acb2de638c38111 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 5 Nov 2012 17:19:11 -0500 Subject: [PATCH] Fix join_inames. --- MEMO | 1 + loopy/__init__.py | 80 ++++++++++++++++++++++++++++++++++------------- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/MEMO b/MEMO index 048223cc9..2ce9d4efc 100644 --- a/MEMO +++ b/MEMO @@ -51,6 +51,7 @@ To-do - ExpandingIdentityMapper extract_subst -> needs WalkMapper padding + replace make_unique_var_name join_inames [DONE] duplicate_inames [DONE] split_iname [DONE] diff --git a/loopy/__init__.py b/loopy/__init__.py index 922c4400f..9171cd6d9 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -39,7 +39,7 @@ from islpy import dim_type from pytools import MovedFunctionDeprecationWrapper -from loopy.symbolic import ExpandingIdentityMapper +from loopy.symbolic import ExpandingIdentityMapper, ExpandingSubstitutionMapper @@ -136,9 +136,6 @@ def split_iname(kernel, split_iname, inner_length, :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. """ - from loopy.context_matching import parse_stack_match - within = parse_stack_match(within) - existing_tag = kernel.iname_to_tag.get(split_iname) from loopy.kernel import ForceSequentialTag if do_tagged_check and ( @@ -151,12 +148,12 @@ def split_iname(kernel, split_iname, inner_length, applied_iname_rewrites = kernel.applied_iname_rewrites[:] + vng = kernel.get_var_name_generator() + if outer_iname is None: - outer_iname = kernel.make_unique_var_name( - split_iname+"_outer") + outer_iname = vng(split_iname+"_outer") if inner_iname is None: - inner_iname = kernel.make_unique_var_name( - split_iname+"_inner") + inner_iname = vng(split_iname+"_inner") def process_set(s): var_dict = s.get_var_dict() @@ -183,11 +180,15 @@ def split_iname(kernel, split_iname, inner_length, space, {split_iname:1, inner_iname: -1, outer_iname:-inner_length}))) name_dim_type, name_idx = space.get_var_dict()[split_iname] - return (s - .intersect(inner_constraint_set) + s = s.intersect(inner_constraint_set) + + if within is None: + s = (s .eliminate(name_dim_type, name_idx, 1) .remove_dims(name_dim_type, name_idx, 1)) + return s + new_domains = [process_set(dom) for dom in kernel.domains] from pymbolic import var @@ -198,7 +199,7 @@ def split_iname(kernel, split_iname, inner_length, subst_map = {var(split_iname): new_loop_index} applied_iname_rewrites.append(subst_map) - # {{{ actually modify instructions + # {{{ update forced_iname deps new_insns = [] for insn in kernel.instructions: @@ -227,6 +228,9 @@ def split_iname(kernel, split_iname, inner_length, applied_iname_rewrites=applied_iname_rewrites, )) + from loopy.context_matching import parse_stack_match + within = parse_stack_match(within) + ins = _InameSplitter(kernel, within, split_iname, outer_iname, inner_iname, new_loop_index) @@ -244,20 +248,48 @@ split_dimension = MovedFunctionDeprecationWrapper(split_iname) # {{{ join inames +class _InameJoiner(ExpandingSubstitutionMapper): + def __init__(self, kernel, within, subst_func, joined_inames, new_iname): + ExpandingSubstitutionMapper.__init__(self, + kernel.substitutions, kernel.get_var_name_generator(), + subst_func, within) + + self.joined_inames = set(joined_inames) + self.new_iname = new_iname + + def map_reduction(self, expr, expn_state): + expr_inames = set(expr.inames) + overlap = self.join_inames & expr_inames + if overlap and self.within(expn_state.stack): + if overlap != expr_inames: + raise RuntimeError( + "Cannot join inames '%s' if there is a reduction " + "that does not use all of the inames being joined. " + "(Found one with just '%s'.)" + % ( + ", ".join(self.joined_inames), + ", ".join(expr_inames))) + + new_inames = expr_inames - self.joined_inames + new_inames.add(self.new_iname) + + from loopy.symbolic import Reduction + return Reduction(expr.operation, tuple(new_inames), + self.rec(expr.expr, expn_state)) + else: + return ExpandingIdentityMapper.map_reduction(self, expr, expn_state) + def join_inames(kernel, inames, new_iname=None, tag=None, within=None): """ :arg inames: fastest varying last :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. """ - from loopy.context_matching import parse_stack_match - within = parse_stack_match(within) - # now fastest varying first inames = inames[::-1] if new_iname is None: - new_iname = kernel.make_unique_var_name("_and_".join(inames)) + new_iname = kernel.get_var_name_generator()("_and_".join(inames)) from loopy.kernel import DomainChanger domch = DomainChanger(kernel, frozenset(inames)) @@ -311,8 +343,10 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): for i, iname in enumerate(inames): iname_to_dim = new_domain.get_space().get_var_dict() iname_dt, iname_idx = iname_to_dim[iname] - new_domain = new_domain.eliminate(iname_dt, iname_idx, 1) - new_domain = new_domain.remove_dims(iname_dt, iname_idx, 1) + + if within is None: + new_domain = new_domain.eliminate(iname_dt, iname_idx, 1) + new_domain = new_domain.remove_dims(iname_dt, iname_idx, 1) def subst_forced_iname_deps(fid): result = set() @@ -336,13 +370,15 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict] )) - from loopy.symbolic import ExpandingSubstitutionMapper + from loopy.context_matching import parse_stack_match + within = parse_stack_match(within) + from pymbolic.mapper.substitutor import make_subst_func - subst_map = ExpandingSubstitutionMapper( - kernel.substitutions, kernel.get_var_name_generator(), - make_subst_func(subst_dict), within) + ijoin = _InameJoiner(kernel, within, + make_subst_func(subst_dict), + inames, new_iname) - kernel = subst_map.map_kernel(kernel) + kernel = ijoin.map_kernel(kernel) if tag is not None: kernel = tag_inames(kernel, {new_iname: tag}) -- GitLab