Skip to content
Snippets Groups Projects
Commit bda23036 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Make join_inames use ExpandingIdentityMapper.

parent c58566b4
No related branches found
No related tags found
No related merge requests found
......@@ -50,8 +50,8 @@ To-do
- ExpandingIdentityMapper
extract_subst -> needs WalkMapper
join_inames
padding
join_inames [DONE]
duplicate_inames [DONE]
split_iname [DONE]
CSE [DONE]
......@@ -70,8 +70,6 @@ To-do
- write_image()
- change_arg_to_image (test!)
- Import SEM test
- Make tests run on GPUs
Fixes:
......
......@@ -244,11 +244,15 @@ split_dimension = MovedFunctionDeprecationWrapper(split_iname)
# {{{ join inames
def join_inames(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()):
def join_inames(kernel, inames, new_iname=None, tag=None, within=None):
"""
:arg inames: fastest varying last
:arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`.
"""
from loopy.context_matching import parse_stack_match
within = parse_stack_match(within)
# now fastest varying first
inames = inames[::-1]
......@@ -310,10 +314,6 @@ def join_inames(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()):
new_domain = new_domain.eliminate(iname_dt, iname_idx, 1)
new_domain = new_domain.remove_dims(iname_dt, iname_idx, 1)
from loopy.symbolic import SubstitutionMapper
from pymbolic.mapper.substitutor import make_subst_func
subst_map = SubstitutionMapper(make_subst_func(subst_dict))
def subst_forced_iname_deps(fid):
result = set()
for iname in fid:
......@@ -326,20 +326,28 @@ def join_inames(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()):
new_insns = [
insn.copy(
assignee=subst_map(insn.assignee),
expression=subst_map(insn.expression),
forced_iname_deps=subst_forced_iname_deps(insn.forced_iname_deps))
for insn in kernel.instructions]
result = (kernel
.map_expressions(subst_map, exclude_instructions=True)
kernel = (kernel
.copy(
instructions=new_insns,
domains=domch.get_domains_with(new_domain),
applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_map]
applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict]
))
return tag_inames(result, {new_iname: tag})
from loopy.symbolic import ExpandingSubstitutionMapper
from pymbolic.mapper.substitutor import make_subst_func
subst_map = ExpandingSubstitutionMapper(
kernel.substitutions, kernel.get_var_name_generator(),
make_subst_func(subst_dict), within)
kernel = subst_map.map_kernel(kernel)
if tag is not None:
kernel = tag_inames(kernel, {new_iname: tag})
return kernel
join_dimensions = MovedFunctionDeprecationWrapper(join_inames)
......
......@@ -163,8 +163,11 @@ def parse_stack_match(smatch):
lowest < next < ... < highest
where `lowest` is necessarily the bottom of the stack. There is currently
no way to anchor to the top of the stack.
where `lowest` is necessarily the bottom of the stack. `...` matches an
arbitrary number of intervening stack levels. There is currently no way to
match the top of the stack.
Each of the entries is an identifier match as understood by :func:`parse_id_match`.
"""
if isinstance(smatch, StackMatchBase):
......
......@@ -385,8 +385,8 @@ class InvocationGatherer(ExpandingIdentityMapper):
ExpandingIdentityMapper.__init__(self,
kernel.substitutions, kernel.get_var_name_generator())
from loopy.symbolic import ParametrizedSubstitutor
self.subst_expander = ParametrizedSubstitutor(
from loopy.symbolic import SubstitutionRuleExpander
self.subst_expander = SubstitutionRuleExpander(
kernel.substitutions)
self.kernel = kernel
......@@ -449,8 +449,8 @@ class InvocationReplacer(ExpandingIdentityMapper):
ExpandingIdentityMapper.__init__(self,
kernel.substitutions, kernel.get_var_name_generator())
from loopy.symbolic import ParametrizedSubstitutor
self.subst_expander = ParametrizedSubstitutor(
from loopy.symbolic import SubstitutionRuleExpander
self.subst_expander = SubstitutionRuleExpander(
kernel.substitutions, kernel.get_var_name_generator())
self.kernel = kernel
......@@ -714,8 +714,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
# extra axes made necessary because they don't occur in the arguments
extra_storage_axes = sweep_inames_set - expanding_usage_arg_deps
from loopy.symbolic import ParametrizedSubstitutor
submap = ParametrizedSubstitutor(kernel.substitutions)
from loopy.symbolic import SubstitutionRuleExpander
submap = SubstitutionRuleExpander(kernel.substitutions)
value_inames = get_dependencies(
submap(subst.expression, insn_id=None)) & kernel.all_inames()
......
......@@ -173,9 +173,9 @@ def extract_subst(kernel, subst_name, template, parameters):
def expand_subst(kernel, ctx_match=None):
from loopy.symbolic import ParametrizedSubstitutor
from loopy.symbolic import SubstitutionRuleExpander
from loopy.context_matching import parse_stack_match
submap = ParametrizedSubstitutor(kernel.substitutions,
submap = SubstitutionRuleExpander(kernel.substitutions,
kernel.get_var_name_generator(),
parse_stack_match(ctx_match))
......
......@@ -474,11 +474,25 @@ class ExpandingIdentityMapper(IdentityMapper):
substitutions=new_substs,
instructions=rename_subst_rules_in_instructions(new_insns, renames))
class ExpandingSubstitutionMapper(ExpandingIdentityMapper):
def __init__(self, rules, make_unique_var_name, subst_func, within):
ExpandingIdentityMapper.__init__(self, rules, make_unique_var_name)
self.subst_func = subst_func
self.within = within
def map_variable(self, expr, expn_state):
result = self.subst_func(expr)
if result is not None or not self.within(expn_state.stack):
return result
else:
return ExpandingIdentityMapper.map_variable(self, expr, expn_state)
# }}}
# {{{ parametrized substitutor
# {{{ substitution rule expander
class ParametrizedSubstitutor(ExpandingIdentityMapper):
class SubstitutionRuleExpander(ExpandingIdentityMapper):
def __init__(self, rules, make_unique_var=None, ctx_match=None):
ExpandingIdentityMapper.__init__(self, rules, make_unique_var)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment