Skip to content
Snippets Groups Projects
Commit 9f9219e7 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Allow setting suffixes for disambiguation of redundant names in fused kernels

parent de300e15
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -32,6 +32,44 @@ from loopy.diagnostic import LoopyError ...@@ -32,6 +32,44 @@ from loopy.diagnostic import LoopyError
from pymbolic import var from pymbolic import var
def _apply_renames_in_exprs(kernel, var_renames):
from loopy.symbolic import (
SubstitutionRuleMappingContext,
RuleAwareSubstitutionMapper)
from pymbolic.mapper.substitutor import make_subst_func
from loopy.context_matching import parse_stack_match
srmc = SubstitutionRuleMappingContext(
kernel.substitutions, kernel.get_var_name_generator())
subst_map = RuleAwareSubstitutionMapper(
srmc, make_subst_func(var_renames),
within=parse_stack_match(None))
return subst_map.map_kernel(kernel)
def _rename_temporaries(kernel, suffix, all_identifiers):
var_renames = {}
vng = kernel.get_var_name_generator()
new_temporaries = {}
for tv in six.itervalues(kernel.temporary_variables):
if tv.name in all_identifiers:
new_tv_name = vng(tv.name+suffix)
else:
new_tv_name = tv.name
if new_tv_name != tv.name:
var_renames[tv.name] = var(new_tv_name)
assert new_tv_name not in new_temporaries
new_temporaries[new_tv_name] = tv.copy(name=new_tv_name)
kernel = kernel.copy(temporary_variables=new_temporaries)
return _apply_renames_in_exprs(kernel, var_renames)
def _find_fusable_loop_domain_index(domain, other_domains): def _find_fusable_loop_domain_index(domain, other_domains):
my_inames = set(domain.get_var_dict(dim_type.set)) my_inames = set(domain.get_var_dict(dim_type.set))
...@@ -168,22 +206,7 @@ def _fuse_two_kernels(knla, knlb): ...@@ -168,22 +206,7 @@ def _fuse_two_kernels(knla, knlb):
# }}} # }}}
# {{{ apply renames in kernel b knlb = _apply_renames_in_exprs(knlb, b_var_renames)
from loopy.symbolic import (
SubstitutionRuleMappingContext,
RuleAwareSubstitutionMapper)
from pymbolic.mapper.substitutor import make_subst_func
from loopy.context_matching import parse_stack_match
srmc = SubstitutionRuleMappingContext(
knlb.substitutions, knlb.get_var_name_generator())
subst_map = RuleAwareSubstitutionMapper(
srmc, make_subst_func(b_var_renames),
within=parse_stack_match(None))
knlb = subst_map.map_kernel(knlb)
# }}}
# {{{ fuse instructions # {{{ fuse instructions
...@@ -286,8 +309,41 @@ def _fuse_two_kernels(knla, knlb): ...@@ -286,8 +309,41 @@ def _fuse_two_kernels(knla, knlb):
# }}} # }}}
def fuse_kernels(kernels): def fuse_kernels(kernels, suffixes=None):
kernels = list(kernels) kernels = list(kernels)
suffixes = list(suffixes)
if suffixes:
if len(suffixes) != len(kernels):
raise ValueError("length of 'suffixes' must match "
"length of 'kernels'")
# {{{ rename temporaries with suffixes
all_identifiers = [
kernel.all_variable_names()
for kernel in kernels]
from functools import reduce, partial
from operator import or_
merge_sets = partial(reduce, or_)
new_kernels = []
for i, (kernel, suffix) in enumerate(zip(kernels, suffixes)):
new_kernels.append(
_rename_temporaries(
kernel,
suffix,
merge_sets(
all_identifiers[:i]
+
all_identifiers[i+1:])))
kernels = new_kernels
del new_kernels
# }}}
result = kernels.pop(0) result = kernels.pop(0)
while kernels: while kernels:
result = _fuse_two_kernels(result, kernels.pop(0)) result = _fuse_two_kernels(result, kernels.pop(0))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment