diff --git a/MEMO b/MEMO index 63851b21f95d44cf5726fa8dbc90e577db0a33b7..048223cc9c8890b7a049276904ff162140d78a8b 100644 --- a/MEMO +++ b/MEMO @@ -50,8 +50,8 @@ To-do - ExpandingIdentityMapper extract_subst -> needs WalkMapper - join_inames padding + join_inames [DONE] duplicate_inames [DONE] split_iname [DONE] CSE [DONE] @@ -70,8 +70,6 @@ To-do - write_image() - change_arg_to_image (test!) -- Import SEM test - - Make tests run on GPUs Fixes: diff --git a/loopy/__init__.py b/loopy/__init__.py index 80afc44a7055a5daba827c17ca271c7022c26780..922c4400f1bd5cb7cba6ed37ab6dd4136a0dd978 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -244,11 +244,15 @@ split_dimension = MovedFunctionDeprecationWrapper(split_iname) # {{{ join inames -def join_inames(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()): +def join_inames(kernel, inames, new_iname=None, tag=None, within=None): """ :arg inames: fastest varying last + :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. 
""" + from loopy.context_matching import parse_stack_match + within = parse_stack_match(within) + # now fastest varying first inames = inames[::-1] @@ -310,10 +314,6 @@ def join_inames(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()): new_domain = new_domain.eliminate(iname_dt, iname_idx, 1) new_domain = new_domain.remove_dims(iname_dt, iname_idx, 1) - from loopy.symbolic import SubstitutionMapper - from pymbolic.mapper.substitutor import make_subst_func - subst_map = SubstitutionMapper(make_subst_func(subst_dict)) - def subst_forced_iname_deps(fid): result = set() for iname in fid: @@ -326,20 +326,28 @@ def join_inames(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()): new_insns = [ insn.copy( - assignee=subst_map(insn.assignee), - expression=subst_map(insn.expression), forced_iname_deps=subst_forced_iname_deps(insn.forced_iname_deps)) for insn in kernel.instructions] - result = (kernel - .map_expressions(subst_map, exclude_instructions=True) + kernel = (kernel .copy( instructions=new_insns, domains=domch.get_domains_with(new_domain), - applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_map] + applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict] )) - return tag_inames(result, {new_iname: tag}) + from loopy.symbolic import ExpandingSubstitutionMapper + from pymbolic.mapper.substitutor import make_subst_func + subst_map = ExpandingSubstitutionMapper( + kernel.substitutions, kernel.get_var_name_generator(), + make_subst_func(subst_dict), within) + + kernel = subst_map.map_kernel(kernel) + + if tag is not None: + kernel = tag_inames(kernel, {new_iname: tag}) + + return kernel join_dimensions = MovedFunctionDeprecationWrapper(join_inames) diff --git a/loopy/context_matching.py b/loopy/context_matching.py index 51cc8a5fd0360fc6f0d392bd81b58e4ae78b3f85..6a428144ea04aa6f1c5491993a6cc19d19818f55 100644 --- a/loopy/context_matching.py +++ b/loopy/context_matching.py @@ -163,8 +163,11 @@ def parse_stack_match(smatch): lowest < 
next < ... < highest - where `lowest` is necessarily the bottom of the stack. There is currently - no way to anchor to the top of the stack. + where `lowest` is necessarily the bottom of the stack. `...` matches an + arbitrary number of intervening stack levels. There is currently no way to + match the top of the stack. + + Each of the entries is an identifier match as understood by :func:`parse_id_match`. """ if isinstance(smatch, StackMatchBase): diff --git a/loopy/cse.py b/loopy/cse.py index 28518cb230372279e829f0fc30da582ceb45303f..65d89d8b70b889858d21c2566b3497517d76fd3e 100644 --- a/loopy/cse.py +++ b/loopy/cse.py @@ -385,8 +385,8 @@ class InvocationGatherer(ExpandingIdentityMapper): ExpandingIdentityMapper.__init__(self, kernel.substitutions, kernel.get_var_name_generator()) - from loopy.symbolic import ParametrizedSubstitutor - self.subst_expander = ParametrizedSubstitutor( + from loopy.symbolic import SubstitutionRuleExpander + self.subst_expander = SubstitutionRuleExpander( kernel.substitutions) self.kernel = kernel @@ -449,8 +449,8 @@ class InvocationReplacer(ExpandingIdentityMapper): ExpandingIdentityMapper.__init__(self, kernel.substitutions, kernel.get_var_name_generator()) - from loopy.symbolic import ParametrizedSubstitutor - self.subst_expander = ParametrizedSubstitutor( + from loopy.symbolic import SubstitutionRuleExpander + self.subst_expander = SubstitutionRuleExpander( kernel.substitutions, kernel.get_var_name_generator()) self.kernel = kernel @@ -714,8 +714,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # extra axes made necessary because they don't occur in the arguments extra_storage_axes = sweep_inames_set - expanding_usage_arg_deps - from loopy.symbolic import ParametrizedSubstitutor - submap = ParametrizedSubstitutor(kernel.substitutions) + from loopy.symbolic import SubstitutionRuleExpander + submap = SubstitutionRuleExpander(kernel.substitutions) value_inames = get_dependencies( submap(subst.expression, 
insn_id=None)) & kernel.all_inames()
diff --git a/loopy/subst.py b/loopy/subst.py
index eec7d74b983a29118bfb64345ca0afe3334b323e..4110670fed1227c62ddc2c76d7729a5e6fb855fa 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -173,9 +173,9 @@ def extract_subst(kernel, subst_name, template, parameters):
 
 
 def expand_subst(kernel, ctx_match=None):
-    from loopy.symbolic import ParametrizedSubstitutor
+    from loopy.symbolic import SubstitutionRuleExpander
     from loopy.context_matching import parse_stack_match
-    submap = ParametrizedSubstitutor(kernel.substitutions,
+    submap = SubstitutionRuleExpander(kernel.substitutions,
             kernel.get_var_name_generator(),
             parse_stack_match(ctx_match))
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 015d29a7a4effa8f1965c1f3fcce88681b037447..6ada11f181d9002c7eb55fe7683addcd7998d9a3 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -474,11 +474,40 @@ class ExpandingIdentityMapper(IdentityMapper):
             substitutions=new_substs,
             instructions=rename_subst_rules_in_instructions(new_insns, renames))
 
+class ExpandingSubstitutionMapper(ExpandingIdentityMapper):
+    """Apply *subst_func* to variables, but only where *within* matches.
+
+    :arg subst_func: called with each candidate variable; a return value
+        of *None* means "no replacement" and the variable is handed to
+        :class:`ExpandingIdentityMapper` instead.
+    :arg within: called with the expansion state's stack; replacement is
+        only attempted where it returns a true value.
+    """
+
+    def __init__(self, rules, make_unique_var_name, subst_func, within):
+        ExpandingIdentityMapper.__init__(self, rules, make_unique_var_name)
+
+        self.subst_func = subst_func
+        self.within = within
+
+    def map_variable(self, expr, expn_state):
+        # Outside the requested context: leave the variable to the base class.
+        if not self.within(expn_state.stack):
+            return ExpandingIdentityMapper.map_variable(self, expr, expn_state)
+
+        result = self.subst_func(expr)
+        if result is not None:
+            return result
+
+        # Nothing to replace: fall back instead of returning None as an
+        # expression.
+        return ExpandingIdentityMapper.map_variable(self, expr, expn_state)
+
 # }}}
 
-# {{{ parametrized substitutor
+# {{{ substitution rule expander
 
-class ParametrizedSubstitutor(ExpandingIdentityMapper):
+class SubstitutionRuleExpander(ExpandingIdentityMapper):
     def __init__(self, rules, make_unique_var=None, ctx_match=None):
         ExpandingIdentityMapper.__init__(self, rules, make_unique_var)
 