From 9f9219e7ca0ff78dae375fec51c0e36669186db8 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 17 Oct 2015 23:48:19 -0500
Subject: [PATCH] Allow setting suffixes for disambiguation of redundant names in fused kernels

---
Review notes (ignored by git-am):
 * fuse_kernels() no longer calls list() on the default suffixes=None,
   which raised TypeError for every caller not passing suffixes.
 * Merging the identifier sets of the "other" kernels now passes an
   initial value to reduce(), so a single kernel with a single suffix
   no longer fails on an empty sequence.
 * Docstrings added to the new helpers and to fuse_kernels().
 * Line structure and indentation of this patch were reconstructed;
   hunk headers and diffstat are updated for the added lines.

 loopy/fusion.py | 118 ++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 101 insertions(+), 17 deletions(-)

diff --git a/loopy/fusion.py b/loopy/fusion.py
index 8845951ea..82ed53edc 100644
--- a/loopy/fusion.py
+++ b/loopy/fusion.py
@@ -32,6 +32,61 @@ from loopy.diagnostic import LoopyError
 from pymbolic import var
 
 
+def _apply_renames_in_exprs(kernel, var_renames):
+    """Substitute variables throughout the expressions of *kernel*.
+
+    :arg var_renames: a mapping from variable names to the
+        expressions that replace them.
+    :returns: the kernel produced by
+        :meth:`RuleAwareSubstitutionMapper.map_kernel`.
+    """
+    from loopy.symbolic import (
+            SubstitutionRuleMappingContext,
+            RuleAwareSubstitutionMapper)
+    from pymbolic.mapper.substitutor import make_subst_func
+    from loopy.context_matching import parse_stack_match
+
+    srmc = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    subst_map = RuleAwareSubstitutionMapper(
+            srmc, make_subst_func(var_renames),
+            within=parse_stack_match(None))
+    return subst_map.map_kernel(kernel)
+
+
+def _rename_temporaries(kernel, suffix, all_identifiers):
+    """Rename the temporaries of *kernel* that clash with other names.
+
+    A temporary whose name occurs in *all_identifiers* gets a new
+    name generated from its old name with *suffix* appended; all
+    other temporaries keep their names. Renamed temporaries are
+    also substituted in the kernel's expressions.
+
+    :arg all_identifiers: a container of names, tested with ``in``.
+    :returns: the kernel with the renames applied.
+    """
+    var_renames = {}
+
+    vng = kernel.get_var_name_generator()
+
+    new_temporaries = {}
+    for tv in six.itervalues(kernel.temporary_variables):
+        if tv.name in all_identifiers:
+            new_tv_name = vng(tv.name+suffix)
+        else:
+            new_tv_name = tv.name
+
+        if new_tv_name != tv.name:
+            var_renames[tv.name] = var(new_tv_name)
+
+        assert new_tv_name not in new_temporaries
+        new_temporaries[new_tv_name] = tv.copy(name=new_tv_name)
+
+    kernel = kernel.copy(temporary_variables=new_temporaries)
+
+    return _apply_renames_in_exprs(kernel, var_renames)
+
+
 def _find_fusable_loop_domain_index(domain, other_domains):
     my_inames = set(domain.get_var_dict(dim_type.set))
 
@@ -168,22 +223,7 @@ def _fuse_two_kernels(knla, knlb):
 
     # }}}
 
-    # {{{ apply renames in kernel b
-
-    from loopy.symbolic import (
-            SubstitutionRuleMappingContext,
-            RuleAwareSubstitutionMapper)
-    from pymbolic.mapper.substitutor import make_subst_func
-    from loopy.context_matching import parse_stack_match
-
-    srmc = SubstitutionRuleMappingContext(
-            knlb.substitutions, knlb.get_var_name_generator())
-    subst_map = RuleAwareSubstitutionMapper(
-            srmc, make_subst_func(b_var_renames),
-            within=parse_stack_match(None))
-    knlb = subst_map.map_kernel(knlb)
-
-    # }}}
+    knlb = _apply_renames_in_exprs(knlb, b_var_renames)
 
     # {{{ fuse instructions
 
@@ -286,8 +326,52 @@ def _fuse_two_kernels(knla, knlb):
 # }}}
 
 
-def fuse_kernels(kernels):
+def fuse_kernels(kernels, suffixes=None):
+    """Fuse *kernels* pairwise, in order, via :func:`_fuse_two_kernels`.
+
+    :arg kernels: an iterable of the kernels to be fused.
+    :arg suffixes: if given, an iterable of strings, one per kernel. A
+        temporary whose name is also a variable name in one of the other
+        kernels is renamed using its own kernel's suffix before fusion.
+    :raises ValueError: if *suffixes* is non-empty and its length differs
+        from the number of kernels.
+    """
     kernels = list(kernels)
+    # None (the default) means "no suffixes"; list(None) would raise.
+    suffixes = [] if suffixes is None else list(suffixes)
+
+    if suffixes:
+        if len(suffixes) != len(kernels):
+            raise ValueError("length of 'suffixes' must match "
+                    "length of 'kernels'")
+
+        # {{{ rename temporaries with suffixes
+
+        all_identifiers = [
+                kernel.all_variable_names()
+                for kernel in kernels]
+
+        from functools import reduce
+        from operator import or_
+
+        new_kernels = []
+        for i, (kernel, suffix) in enumerate(zip(kernels, suffixes)):
+            # Names used by every kernel other than this one. The
+            # initial value keeps reduce() from failing when there are
+            # no other kernels.
+            other_identifiers = reduce(
+                    or_,
+                    all_identifiers[:i] + all_identifiers[i+1:],
+                    frozenset())
+
+            new_kernels.append(
+                    _rename_temporaries(kernel, suffix, other_identifiers))
+
+        kernels = new_kernels
+        del new_kernels
+
+        # }}}
+
     result = kernels.pop(0)
     while kernels:
         result = _fuse_two_kernels(result, kernels.pop(0))
-- 
GitLab