From acfd14acec6c8f58e0d6cfbeef60124489fc2627 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 19 Apr 2015 19:12:49 -0500 Subject: [PATCH] Refactor rule-aware term rewriting using SubstitutionRuleMappingContext --- MEMO | 2 +- loopy/__init__.py | 85 +++++++++----- loopy/buffer.py | 19 +-- loopy/kernel/creation.py | 18 +-- loopy/padding.py | 16 +-- loopy/precompute.py | 116 ++++++++++--------- loopy/preprocess.py | 6 +- loopy/subst.py | 40 +++---- loopy/symbolic.py | 244 ++++++++++++++++++++++----------------- 9 files changed, 307 insertions(+), 239 deletions(-) diff --git a/MEMO b/MEMO index d650e7ea4..f4e5c34e4 100644 --- a/MEMO +++ b/MEMO @@ -142,7 +142,7 @@ Dealt with - How can one automatically generate something like microblocks? -> Some sort of axis-adding transform? -- ExpandingIdentityMapper +- RuleAwareIdentityMapper extract_subst -> needs WalkMapper [actually fine as is] padding [DONE] replace make_unique_var_name [DONE] diff --git a/loopy/__init__.py b/loopy/__init__.py index 364540856..194a2c7b1 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -30,7 +30,8 @@ THE SOFTWARE. 
import islpy as isl from islpy import dim_type -from loopy.symbolic import ExpandingIdentityMapper, ExpandingSubstitutionMapper +from loopy.symbolic import (RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext) from loopy.diagnostic import LoopyError @@ -126,11 +127,10 @@ __all__ = [ # {{{ split inames -class _InameSplitter(ExpandingIdentityMapper): - def __init__(self, kernel, within, +class _InameSplitter(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, within, split_iname, outer_iname, inner_iname, replacement_index): - ExpandingIdentityMapper.__init__(self, - kernel.substitutions, kernel.get_var_name_generator()) + super(_InameSplitter, self).__init__(rule_mapping_context) self.within = within @@ -277,10 +277,13 @@ def split_iname(kernel, split_iname, inner_length, from loopy.context_matching import parse_stack_match within = parse_stack_match(within) - ins = _InameSplitter(kernel, within, + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + ins = _InameSplitter(rule_mapping_context, within, split_iname, outer_iname, inner_iname, new_loop_index) kernel = ins.map_kernel(kernel) + kernel = rule_mapping_context.finish_kernel(kernel) if existing_tag is not None: kernel = tag_inames(kernel, @@ -293,10 +296,10 @@ def split_iname(kernel, split_iname, inner_length, # {{{ join inames -class _InameJoiner(ExpandingSubstitutionMapper): - def __init__(self, kernel, within, subst_func, joined_inames, new_iname): - ExpandingSubstitutionMapper.__init__(self, - kernel.substitutions, kernel.get_var_name_generator(), +class _InameJoiner(RuleAwareSubstitutionMapper): + def __init__(self, rule_mapping_context, within, subst_func, + joined_inames, new_iname): + super(_InameJoiner, self).__init__(rule_mapping_context, subst_func, within) self.joined_inames = set(joined_inames) @@ -425,11 +428,14 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): 
within = parse_stack_match(within) from pymbolic.mapper.substitutor import make_subst_func - ijoin = _InameJoiner(kernel, within, + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + ijoin = _InameJoiner(rule_mapping_context, within, make_subst_func(subst_dict), inames, new_iname) - kernel = ijoin.map_kernel(kernel) + kernel = rule_mapping_context.finish_kernel( + ijoin.map_kernel(kernel)) if tag is not None: kernel = tag_inames(kernel, {new_iname: tag}) @@ -494,11 +500,10 @@ def tag_inames(kernel, iname_to_tag, force=False): # {{{ duplicate inames -class _InameDuplicator(ExpandingIdentityMapper): - def __init__(self, rules, make_unique_var_name, +class _InameDuplicator(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, old_to_new, within): - super(_InameDuplicator, self).__init__( - rules, make_unique_var_name) + super(_InameDuplicator, self).__init__(rule_mapping_context) self.old_to_new = old_to_new self.old_inames_set = set(six.iterkeys(old_to_new)) @@ -602,11 +607,14 @@ def duplicate_inames(knl, inames, within, new_inames=None, suffix=None, # {{{ change the inames in the code - indup = _InameDuplicator(knl.substitutions, name_gen, + rule_mapping_context = SubstitutionRuleMappingContext( + knl.substitutions, name_gen) + indup = _InameDuplicator(rule_mapping_context, old_to_new=dict(list(zip(inames, new_inames))), within=within) - knl = indup.map_kernel(knl) + knl = rule_mapping_context.finish_kernel( + indup.map_kernel(knl)) # }}} @@ -732,10 +740,13 @@ def link_inames(knl, inames, new_iname, within=None, tag=None): within = parse_stack_match(within) from pymbolic.mapper.substitutor import make_subst_func - ijoin = ExpandingSubstitutionMapper(knl.substitutions, var_name_gen, + rule_mapping_context = SubstitutionRuleMappingContext( + knl.substitutions, var_name_gen) + ijoin = RuleAwareSubstitutionMapper(rule_mapping_context, make_subst_func(subst_dict), within) - knl = 
ijoin.map_kernel(knl) + knl = rule_mapping_context.finish_kernel( + ijoin.map_kernel(knl)) # }}} @@ -1187,10 +1198,10 @@ def tag_data_axes(knl, ary_names, dim_tags): # {{{ split_reduction -class _ReductionSplitter(ExpandingIdentityMapper): - def __init__(self, kernel, within, inames, direction): +class _ReductionSplitter(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, within, inames, direction): super(_ReductionSplitter, self).__init__( - kernel.substitutions, kernel.get_var_name_generator()) + rule_mapping_context) self.within = within self.inames = inames @@ -1230,8 +1241,12 @@ def _split_reduction(kernel, inames, direction, within=None): from loopy.context_matching import parse_stack_match within = parse_stack_match(within) - rsplit = _ReductionSplitter(kernel, within, inames, direction) - return rsplit.map_kernel(kernel) + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + rsplit = _ReductionSplitter(rule_mapping_context, + within, inames, direction) + return rule_mapping_context.finish_kernel( + rsplit.map_kernel(kernel)) def split_reduction_inward(kernel, inames, within=None): @@ -1328,11 +1343,13 @@ def _fix_parameter(kernel, name, value): from loopy.context_matching import parse_stack_match within = parse_stack_match(None) - from loopy.symbolic import ExpandingSubstitutionMapper - esubst_map = ExpandingSubstitutionMapper( - kernel.substitutions, kernel.get_var_name_generator(), - subst_func, within=within) - return (esubst_map.map_kernel(kernel) + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + esubst_map = RuleAwareSubstitutionMapper( + rule_mapping_context, subst_func, within=within) + return ( + rule_mapping_context.finish_kernel( + esubst_map.map_kernel(kernel)) .copy( domains=new_domains, args=new_args, @@ -1633,10 +1650,14 @@ def affine_map_inames(kernel, old_inames, new_inames, equations): var_name_gen 
= kernel.get_var_name_generator() from pymbolic.mapper.substitutor import make_subst_func - old_to_new = ExpandingSubstitutionMapper(kernel.substitutions, var_name_gen, + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, var_name_gen) + old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context, make_subst_func(subst_dict), within=lambda stack: True) - kernel = (old_to_new.map_kernel(kernel) + kernel = ( + rule_mapping_context.finish_kernel( + old_to_new.map_kernel(kernel)) .copy( applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict] )) diff --git a/loopy/buffer.py b/loopy/buffer.py index a5721890e..c64943f14 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -25,7 +25,8 @@ THE SOFTWARE. from loopy.array_buffer_map import (ArrayToBufferMap, NoOpArrayToBufferMap, AccessDescriptor) -from loopy.symbolic import (get_dependencies, ExpandingIdentityMapper, +from loopy.symbolic import (get_dependencies, + RuleAwareIdentityMapper, SubstitutionRuleMappingContext, SubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func @@ -34,12 +35,11 @@ from pymbolic import var # {{{ replace array access -class ArrayAccessReplacer(ExpandingIdentityMapper): - def __init__(self, kernel, var_name, within, array_base_map, buf_var): - super(ArrayAccessReplacer, self).__init__( - kernel.substitutions, kernel.get_var_name_generator()) +class ArrayAccessReplacer(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, + var_name, within, array_base_map, buf_var): + super(ArrayAccessReplacer, self).__init__(rule_mapping_context) - self.kernel = kernel self.within = within self.array_base_map = array_base_map @@ -320,8 +320,11 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, # }}} - aar = ArrayAccessReplacer(kernel, var_name, within, abm, buf_var) - kernel = aar.map_kernel(kernel) + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + 
aar = ArrayAccessReplacer(rule_mapping_context, var_name, + within, abm, buf_var) + kernel = rule_mapping_context.finish_kernel(aar.map_kernel(kernel)) did_write = False for insn_id in aar.modified_insn_ids: diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index a21d983d7..e99dedbf9 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -459,16 +459,16 @@ class ArgumentGuesser: (assignee_var_name, _), = insn.assignees_and_indices() self.all_written_names.add(assignee_var_name) self.all_names.update(get_dependencies( - self.submap(insn.assignee, insn.id, insn.tags))) + self.submap(insn.assignee))) self.all_names.update(get_dependencies( - self.submap(insn.expression, insn.id, insn.tags))) + self.submap(insn.expression))) def find_index_rank(self, name): irf = IndexRankFinder(name) for insn in self.instructions: insn.with_transformed_expressions( - lambda expr: irf(self.submap(expr, insn.id, insn.tags))) + lambda expr: irf(self.submap(expr))) if not irf.index_ranks: return 0 @@ -859,8 +859,7 @@ def guess_arg_shape_if_requested(kernel, default_order): from loopy.kernel.array import ArrayBase from loopy.symbolic import SubstitutionRuleExpander, AccessRangeMapper - submap = SubstitutionRuleExpander(kernel.substitutions, - kernel.get_var_name_generator()) + submap = SubstitutionRuleExpander(kernel.substitutions) for arg in kernel.args: if isinstance(arg, ArrayBase) and arg.shape is lp.auto: @@ -869,11 +868,12 @@ def guess_arg_shape_if_requested(kernel, default_order): try: for insn in kernel.instructions: if isinstance(insn, lp.ExpressionInstruction): - armap(submap(insn.assignee, insn.id, insn.tags), - kernel.insn_inames(insn)) - armap(submap(insn.expression, insn.id, insn.tags), - kernel.insn_inames(insn)) + armap(submap(insn.assignee), kernel.insn_inames(insn)) + armap(submap(insn.expression), kernel.insn_inames(insn)) except TypeError as e: + from traceback import print_exc + print_exc() + from loopy.diagnostic import LoopyError raise 
LoopyError( "Failed to (automatically, as requested) find " diff --git a/loopy/padding.py b/loopy/padding.py index 8f747d413..4951c6854 100644 --- a/loopy/padding.py +++ b/loopy/padding.py @@ -25,12 +25,12 @@ THE SOFTWARE. """ -from loopy.symbolic import ExpandingIdentityMapper +from loopy.symbolic import RuleAwareIdentityMapper, SubstitutionRuleMappingContext -class ArgAxisSplitHelper(ExpandingIdentityMapper): - def __init__(self, rules, var_name_gen, arg_names, handler): - ExpandingIdentityMapper.__init__(self, rules, var_name_gen) +class ArgAxisSplitHelper(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, arg_names, handler): + super(ArgAxisSplitHelper, self).__init__(rule_mapping_context) self.arg_names = arg_names self.handler = handler @@ -38,7 +38,7 @@ class ArgAxisSplitHelper(ExpandingIdentityMapper): if expr.aggregate.name in self.arg_names: return self.handler(expr) else: - return ExpandingIdentityMapper.map_subscript(self, expr, expn_state) + return super(ArgAxisSplitHelper, self).map_subscript(expr, expn_state) def split_arg_axis(kernel, args_and_axes, count, auto_split_inames=True): @@ -205,9 +205,11 @@ def split_arg_axis(kernel, args_and_axes, count, auto_split_inames=True): return expr.aggregate.index(tuple(idx)) - aash = ArgAxisSplitHelper(kernel.substitutions, var_name_gen, + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, var_name_gen) + aash = ArgAxisSplitHelper(rule_mapping_context, set(six.iterkeys(arg_to_rest)), split_access_axis) - kernel = aash.map_kernel(kernel) + kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel)) kernel = kernel.copy(args=new_args) diff --git a/loopy/precompute.py b/loopy/precompute.py index d37b53d7b..e516dde13 100644 --- a/loopy/precompute.py +++ b/loopy/precompute.py @@ -25,8 +25,9 @@ THE SOFTWARE. 
""" -from loopy.symbolic import (get_dependencies, ExpandingSubstitutionMapper, - ExpandingIdentityMapper) +from loopy.symbolic import (get_dependencies, + RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext) from pymbolic.mapper.substitutor import make_subst_func import numpy as np @@ -58,10 +59,9 @@ def storage_axis_exprs(storage_axis_sources, args): # {{{ gather rule invocations -class RuleInvocationGatherer(ExpandingIdentityMapper): - def __init__(self, kernel, subst_name, subst_tag, within): - ExpandingIdentityMapper.__init__(self, - kernel.substitutions, kernel.get_var_name_generator()) +class RuleInvocationGatherer(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): + super(RuleInvocationGatherer, self).__init__(rule_mapping_context) from loopy.symbolic import SubstitutionRuleExpander self.subst_expander = SubstitutionRuleExpander( @@ -83,18 +83,17 @@ class RuleInvocationGatherer(ExpandingIdentityMapper): process_me = process_me and self.within(expn_state.stack) if not process_me: - return ExpandingIdentityMapper.map_substitution( - self, name, tag, arguments, expn_state) + return super(RuleInvocationGatherer, self).map_substitution( + name, tag, arguments, expn_state) - rule = self.old_subst_rules[name] + rule = self.rule_mapping_context.old_subst_rules[name] arg_context = self.make_new_arg_context( name, rule.arguments, arguments, expn_state.arg_context) arg_deps = set() for arg_val in six.itervalues(arg_context): arg_deps = (arg_deps - | get_dependencies(self.subst_expander( - arg_val, insn_id=None, insn_tags=None))) + | get_dependencies(self.subst_expander(arg_val))) # FIXME: This is too strict--and the footprint machinery # needs to be taught how to deal with locally constant @@ -109,8 +108,8 @@ class RuleInvocationGatherer(ExpandingIdentityMapper): ", ".join(arg_deps - self.kernel.all_inames()), )) - return ExpandingIdentityMapper.map_substitution( - self, name, 
tag, arguments, expn_state) + return super(RuleInvocationGatherer, self).map_substitution( + name, tag, arguments, expn_state) args = [arg_context[arg_name] for arg_name in rule.arguments] @@ -127,16 +126,14 @@ class RuleInvocationGatherer(ExpandingIdentityMapper): # {{{ replace rule invocation -class RuleInvocationReplacer(ExpandingIdentityMapper): - def __init__(self, kernel, subst_name, subst_tag, within, +class RuleInvocationReplacer(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, subst_name, subst_tag, within, access_descriptors, array_base_map, storage_axis_names, storage_axis_sources, non1_storage_axis_names, target_var_name): - ExpandingIdentityMapper.__init__(self, - kernel.substitutions, kernel.get_var_name_generator()) + super(RuleInvocationReplacer, self).__init__(rule_mapping_context) - self.kernel = kernel self.subst_name = subst_name self.subst_tag = subst_tag self.within = within @@ -155,12 +152,12 @@ class RuleInvocationReplacer(ExpandingIdentityMapper): name == self.subst_name and self.within(expn_state.stack) and (self.subst_tag is None or self.subst_tag == tag)): - return ExpandingIdentityMapper.map_substitution( - self, name, tag, arguments, expn_state) + return super(RuleInvocationReplacer, self).map_substitution( + name, tag, arguments, expn_state) # {{{ check if in footprint - rule = self.old_subst_rules[name] + rule = self.rule_mapping_context.old_subst_rules[name] arg_context = self.make_new_arg_context( name, rule.arguments, arguments, expn_state.arg_context) args = [arg_context[arg_name] for arg_name in rule.arguments] @@ -170,8 +167,8 @@ class RuleInvocationReplacer(ExpandingIdentityMapper): self.storage_axis_sources, args)) if not self.array_base_map.is_access_descriptor_in_footprint(accdesc): - return ExpandingIdentityMapper.map_substitution( - self, name, tag, arguments, expn_state) + return super(RuleInvocationReplacer, self).map_substitution( + name, tag, arguments, expn_state) # }}} @@ -202,10 +199,9 @@ class 
RuleInvocationReplacer(ExpandingIdentityMapper): if stor_subscript: new_outer_expr = new_outer_expr.index(tuple(stor_subscript)) - # Can't possibly be nested, but recurse anyway to - # make sure substitution rules referenced below here - # do not get thrown away. - self.rec(rule.expression, expn_state.copy(arg_context={})) + # Can't possibly be nested, and no need to traverse + # further as compute expression has already been seen + # by rule_mapping_context. return new_outer_expr @@ -369,7 +365,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # {{{ gather up invocations in kernel code, finish access_descriptors if not footprint_generators: - invg = RuleInvocationGatherer(kernel, subst_name, subst_tag, within) + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + invg = RuleInvocationGatherer( + rule_mapping_context, kernel, subst_name, subst_tag, within) + del rule_mapping_context import loopy as lp for insn in kernel.instructions: @@ -405,7 +405,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, submap = SubstitutionRuleExpander(kernel.substitutions) value_inames = get_dependencies( - submap(subst.expression, insn_id=None, insn_tags=None) + submap(subst.expression) ) & kernel.all_inames() if value_inames - expanding_usage_arg_deps < extra_storage_axes: raise RuntimeError("unreferenced sweep inames specified: " @@ -513,6 +513,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, non1_storage_axis_names = [] abm = NoOpArrayToBufferMap() + kernel = kernel.copy(domains=new_kernel_domains) + # {{{ set up compute insn target_var_name = var_name_gen(based_on=c_subst_name) @@ -522,24 +524,54 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, assignee = assignee.index( tuple(var(iname) for iname in non1_storage_axis_names)) + # {{{ process substitutions on compute instruction + + storage_axis_subst_dict = {} + + for arg_name, bi in 
zip(storage_axis_names, abm.storage_base_indices): + if arg_name in non1_storage_axis_names: + arg = var(arg_name) + else: + arg = 0 + + storage_axis_subst_dict[ + prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi + + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + + from loopy.context_matching import AllStackMatch + expr_subst_map = RuleAwareSubstitutionMapper( + rule_mapping_context, + make_subst_func(storage_axis_subst_dict), + within=AllStackMatch()) + + compute_expression = expr_subst_map(subst.expression, None, None) + + # }}} + from loopy.kernel.data import ExpressionInstruction compute_insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name) compute_insn = ExpressionInstruction( id=compute_insn_id, assignee=assignee, - expression=subst.expression) + expression=compute_expression) # }}} # {{{ substitute rule into expressions in kernel (if within footprint) - invr = RuleInvocationReplacer(kernel, subst_name, subst_tag, within, + invr = RuleInvocationReplacer(rule_mapping_context, + subst_name, subst_tag, within, access_descriptors, abm, storage_axis_names, storage_axis_sources, non1_storage_axis_names, target_var_name) kernel = invr.map_kernel(kernel) + kernel = kernel.copy( + instructions=[compute_insn] + kernel.instructions) + kernel = rule_mapping_context.finish_kernel(kernel) # }}} @@ -566,31 +598,9 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, new_temporary_variables[target_var_name] = temp_var - # }}} - kernel = kernel.copy( - domains=new_kernel_domains, - instructions=[compute_insn] + kernel.instructions, temporary_variables=new_temporary_variables) - # {{{ process substitutions on compute instruction - - storage_axis_subst_dict = {} - - for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices): - if arg_name in non1_storage_axis_names: - arg = var(arg_name) - else: - arg = 0 - - 
storage_axis_subst_dict[prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi - - expr_subst_map = ExpandingSubstitutionMapper( - kernel.substitutions, kernel.get_var_name_generator(), - make_subst_func(storage_axis_subst_dict), - parse_stack_match("... < "+compute_insn_id)) - kernel = expr_subst_map.map_kernel(kernel) - # }}} from loopy import tag_inames diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 145272907..082b18895 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -88,8 +88,7 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander): if not isinstance(writer_insn, lp.ExpressionInstruction): continue - expr = subst_expander(writer_insn.expression, - insn_id=writer_insn_id, insn_tags=writer_insn.tags) + expr = subst_expander(writer_insn.expression) try: debug(" via expr %s" % expr) @@ -171,8 +170,7 @@ def infer_unknown_types(kernel, expect_completion=False): ])) from loopy.symbolic import SubstitutionRuleExpander - subst_expander = SubstitutionRuleExpander(kernel.substitutions, - kernel.get_var_name_generator()) + subst_expander = SubstitutionRuleExpander(kernel.substitutions) # {{{ work on type inference queue diff --git a/loopy/subst.py b/loopy/subst.py index d1a643f30..412623c90 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -27,7 +27,7 @@ THE SOFTWARE. 
from loopy.symbolic import ( get_dependencies, SubstitutionMapper, - ExpandingIdentityMapper) + RuleAwareIdentityMapper, SubstitutionRuleMappingContext) from loopy.diagnostic import LoopyError from pymbolic.mapper.substitutor import make_subst_func @@ -200,15 +200,13 @@ def extract_subst(kernel, subst_name, template, parameters=()): # {{{ temporary_to_subst -class TemporaryToSubstChanger(ExpandingIdentityMapper): - def __init__(self, kernel, temp_name, definition_insn_ids, +class TemporaryToSubstChanger(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, temp_name, definition_insn_ids, usage_to_definition, within): - self.var_name_gen = kernel.get_var_name_generator() + self.var_name_gen = rule_mapping_context.make_unique_var_name - super(TemporaryToSubstChanger, self).__init__( - kernel.substitutions, self.var_name_gen) + super(TemporaryToSubstChanger, self).__init__(rule_mapping_context) - self.kernel = kernel self.temp_name = temp_name self.definition_insn_ids = definition_insn_ids self.usage_to_definition = usage_to_definition @@ -349,10 +347,13 @@ def temporary_to_subst(kernel, temp_name, within=None): from loopy.context_matching import parse_stack_match within = parse_stack_match(within) - tts = TemporaryToSubstChanger(kernel, temp_name, definition_insn_ids, + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + tts = TemporaryToSubstChanger(rule_mapping_context, + temp_name, definition_insn_ids, usage_to_definition, within) - kernel = tts.map_kernel(kernel) + kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel)) from loopy.kernel.data import SubstitutionRule @@ -412,19 +413,18 @@ def temporary_to_subst(kernel, temp_name, within=None): # }}} -def expand_subst(kernel, ctx_match=None): +def expand_subst(kernel, within=None): logger.debug("%s: expand subst" % kernel.name) - from loopy.symbolic import SubstitutionRuleExpander + from loopy.symbolic import 
RuleAwareSubstitutionRuleExpander from loopy.context_matching import parse_stack_match - submap = SubstitutionRuleExpander(kernel.substitutions, - kernel.get_var_name_generator(), - parse_stack_match(ctx_match)) - - kernel = submap.map_kernel(kernel) - if ctx_match is None: - return kernel.copy(substitutions={}) - else: - return kernel + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + submap = RuleAwareSubstitutionRuleExpander( + rule_mapping_context, + kernel.substitutions, + parse_stack_match(within)) + + return rule_mapping_context.finish_kernel(submap.map_kernel(kernel)) # vim: foldmethod=marker diff --git a/loopy/symbolic.py b/loopy/symbolic.py index a17c4cdd5..8897f3593 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1,11 +1,8 @@ """Pymbolic mappers for loopy.""" -from __future__ import division -from __future__ import absolute_import +from __future__ import division, absolute_import import six -from six.moves import range -from six.moves import zip -from functools import reduce +from six.moves import range, zip, reduce __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -197,6 +194,41 @@ class DependencyMapper(DependencyMapperBase): map_linear_subscript = DependencyMapperBase.map_subscript + +class SubstitutionRuleExpander(IdentityMapper): + def __init__(self, rules): + self.rules = rules + + def map_variable(self, expr): + if expr.name in self.rules: + return self.map_substitution(expr.name, self.rules[expr.name], ()) + else: + return super(SubstitutionRuleExpander, self).map_variable(expr) + + def map_call(self, expr): + if expr.function.name in self.rules: + return self.map_substitution( + expr.function.name, + self.rules[expr.function.name], + expr.parameters) + else: + return super(SubstitutionRuleExpander, self).map_call(expr) + + def map_substitution(self, name, rule, arguments): + if len(rule.arguments) != len(arguments): + from loopy.diagnostic import LoopyError + 
raise LoopyError("number of arguments to '%s' does not match " + "definition" % name) + + from pymbolic.mapper.substitutor import make_subst_func + submap = SubstitutionMapper( + make_subst_func( + dict(zip(rule.arguments, arguments)))) + + expr = submap(rule.expression) + + return self.rec(expr) + # }}} @@ -333,7 +365,7 @@ def get_dependencies(expr): return frozenset(dep.name for dep in dep_mapper(expr)) -# {{{ identity mapper that expands subst rules on the fly +# {{{ rule-aware mappers def parse_tagged_name(expr): if isinstance(expr, TaggedVariable): @@ -409,14 +441,7 @@ def rename_subst_rules_in_instructions(insns, renames): for insn in insns] -class ExpandingIdentityMapper(IdentityMapper): - """Note: the third argument dragged around by this mapper is the - current :class:`ExpansionState`. - - Subclasses of this must be careful to not touch identifiers that - are in :attr:`ExpansionState.arg_context`. - """ - +class SubstitutionRuleMappingContext(object): def __init__(self, old_subst_rules, make_unique_var_name): self.old_subst_rules = old_subst_rules self.make_unique_var_name = make_unique_var_name @@ -445,9 +470,84 @@ class ExpandingIdentityMapper(IdentityMapper): self.subst_rule_use_count[key] = self.subst_rule_use_count.get(key, 0) + 1 return new_name + def _get_new_substitutions_and_renames(self): + """This makes a new dictionary of substitutions from the ones + encountered in mapping all the encountered expressions. + It tries hard to keep substitution names the same--i.e. + if all derivative versions of a substitution rule ended + up with the same mapped version, then this version should + retain the name that the substitution rule had previously. + Unfortunately, this can't be done in a single pass, and so + the routine returns an additional dictionary *subst_renames* + of renamings to be performed on the processed expressions. + + The returned substitutions already have the rename applied + to them. 
+ + :returns: (new_substitutions, subst_renames) + """ + + from loopy.kernel.data import SubstitutionRule + + orig_name_histogram = {} + for key, (name, orig_name) in six.iteritems(self.subst_rule_registry): + if self.subst_rule_use_count.get(key, 0): + orig_name_histogram[orig_name] = \ + orig_name_histogram.get(orig_name, 0) + 1 + + result = {} + renames = {} + + for key, (name, orig_name) in six.iteritems(self.subst_rule_registry): + args, body = key + + if self.subst_rule_use_count.get(key, 0): + if orig_name_histogram[orig_name] == 1 and name != orig_name: + renames[name] = orig_name + name = orig_name + + result[name] = SubstitutionRule( + name=name, + arguments=args, + expression=body) + + # {{{ perform renames on new substitutions + + subst_renamer = SubstitutionRuleRenamer(renames) + + renamed_result = {} + for name, rule in six.iteritems(result): + renamed_result[name] = rule.copy( + expression=subst_renamer(rule.expression)) + + # }}} + + return renamed_result, renames + + def finish_kernel(self, kernel): + new_substs, renames = self._get_new_substitutions_and_renames() + + new_insns = rename_subst_rules_in_instructions(kernel.instructions, renames) + + return kernel.copy( + substitutions=new_substs, + instructions=new_insns) + + +class RuleAwareIdentityMapper(IdentityMapper): + """Note: the third argument dragged around by this mapper is the + current :class:`ExpansionState`. + + Subclasses of this must be careful to not touch identifiers that + are in :attr:`ExpansionState.arg_context`. 
+ """ + + def __init__(self, rule_mapping_context): + self.rule_mapping_context = rule_mapping_context + def map_variable(self, expr, expn_state): name, tag = parse_tagged_name(expr) - if name not in self.old_subst_rules: + if name not in self.rule_mapping_context.old_subst_rules: return IdentityMapper.map_variable(self, expr, expn_state) else: return self.map_substitution(name, tag, (), expn_state) @@ -458,8 +558,8 @@ class ExpandingIdentityMapper(IdentityMapper): name, tag = parse_tagged_name(expr.function) - if name not in self.old_subst_rules: - return IdentityMapper.map_call(self, expr, expn_state) + if name not in self.rule_mapping_context.old_subst_rules: + return super(RuleAwareIdentityMapper, self).map_call(expr, expn_state) else: return self.map_substitution(name, tag, expr.parameters, expn_state) @@ -476,7 +576,7 @@ class ExpandingIdentityMapper(IdentityMapper): for formal_arg_name, arg_value in zip(arg_names, arguments)) def map_substitution(self, name, tag, arguments, expn_state): - rule = self.old_subst_rules[name] + rule = self.rule_mapping_context.old_subst_rules[name] rec_arguments = self.rec(arguments, expn_state) @@ -492,7 +592,8 @@ class ExpandingIdentityMapper(IdentityMapper): result = self.rec(rule.expression, new_expn_state) - new_name = self.register_subst_rule(name, rule.arguments, result) + new_name = self.rule_mapping_context.register_subst_rule( + name, rule.arguments, result) if tag is None: sym = Variable(new_name) @@ -513,60 +614,6 @@ class ExpandingIdentityMapper(IdentityMapper): return IdentityMapper.__call__(self, expr, ExpansionState( stack=stack, arg_context={})) - def _get_new_substitutions_and_renames(self): - """This makes a new dictionary of substitutions from the ones - encountered in mapping all the encountered expressions. - It tries hard to keep substitution names the same--i.e. 
- if all derivative versions of a substitution rule ended - up with the same mapped version, then this version should - retain the name that the substitution rule had previously. - Unfortunately, this can't be done in a single pass, and so - the routine returns an additional dictionary *subst_renames* - of renamings to be performed on the processed expressions. - - The returned substitutions already have the rename applied - to them. - - :returns: (new_substitutions, subst_renames) - """ - - from loopy.kernel.data import SubstitutionRule - - orig_name_histogram = {} - for key, (name, orig_name) in six.iteritems(self.subst_rule_registry): - if self.subst_rule_use_count.get(key, 0): - orig_name_histogram[orig_name] = \ - orig_name_histogram.get(orig_name, 0) + 1 - - result = {} - renames = {} - - for key, (name, orig_name) in six.iteritems(self.subst_rule_registry): - args, body = key - - if self.subst_rule_use_count.get(key, 0): - if orig_name_histogram[orig_name] == 1 and name != orig_name: - renames[name] = orig_name - name = orig_name - - result[name] = SubstitutionRule( - name=name, - arguments=args, - expression=body) - - # {{{ perform renames on new substitutions - - subst_renamer = SubstitutionRuleRenamer(renames) - - renamed_result = {} - for name, rule in six.iteritems(result): - renamed_result[name] = rule.copy( - expression=subst_renamer(rule.expression)) - - # }}} - - return renamed_result, renames - def map_instruction(self, insn): return insn @@ -575,24 +622,16 @@ class ExpandingIdentityMapper(IdentityMapper): # While subst rules are not allowed in assignees, the mapper # may perform tasks entirely unrelated to subst rules, so # we must map assignees, too. 
- - insn.with_transformed_expressions(self, insn.id, insn.tags) + self.map_instruction( + insn.with_transformed_expressions(self, insn.id, insn.tags)) for insn in kernel.instructions] - new_substs, renames = self._get_new_substitutions_and_renames() + return kernel.copy(instructions=new_insns) - new_insns = [self.map_instruction(insn) - for insn in rename_subst_rules_in_instructions( - new_insns, renames)] - - return kernel.copy( - substitutions=new_substs, - instructions=new_insns) - -class ExpandingSubstitutionMapper(ExpandingIdentityMapper): - def __init__(self, rules, make_unique_var_name, subst_func, within): - ExpandingIdentityMapper.__init__(self, rules, make_unique_var_name) +class RuleAwareSubstitutionMapper(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, subst_func, within): + super(RuleAwareSubstitutionMapper, self).__init__(rule_mapping_context) self.subst_func = subst_func self.within = within @@ -600,28 +639,23 @@ class ExpandingSubstitutionMapper(ExpandingIdentityMapper): def map_variable(self, expr, expn_state): if (expr.name in expn_state.arg_context or not self.within(expn_state.stack)): - return ExpandingIdentityMapper.map_variable(self, expr, expn_state) + return super(RuleAwareSubstitutionMapper, self).map_variable( + expr, expn_state) result = self.subst_func(expr) if result is not None: return result else: - return ExpandingIdentityMapper.map_variable(self, expr, expn_state) + return super(RuleAwareSubstitutionMapper, self).map_variable( + expr, expn_state) -# }}} +class RuleAwareSubstitutionRuleExpander(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, rules, within): + super(RuleAwareSubstitutionRuleExpander, self).__init__(rule_mapping_context) -# {{{ substitution rule expander - -class SubstitutionRuleExpander(ExpandingIdentityMapper): - def __init__(self, rules, make_unique_var=None, ctx_match=None): - ExpandingIdentityMapper.__init__(self, rules, make_unique_var) - - if ctx_match is None: - from 
loopy.context_matching import AllStackMatch - ctx_match = AllStackMatch() - - self.ctx_match = ctx_match + self.rules = rules + self.within = within def map_substitution(self, name, tag, arguments, expn_state): if tag is None: @@ -631,9 +665,9 @@ class SubstitutionRuleExpander(ExpandingIdentityMapper): new_stack = expn_state.stack + ((name, tags),) - if self.ctx_match(new_stack): + if self.within(new_stack): # expand - rule = self.old_subst_rules[name] + rule = self.rules[name] new_expn_state = expn_state.copy( stack=new_stack, @@ -651,8 +685,8 @@ class SubstitutionRuleExpander(ExpandingIdentityMapper): else: # do not expand - return ExpandingIdentityMapper.map_substitution( - self, name, tag, arguments, expn_state) + return super(RuleAwareSubstitutionRuleExpander, self).map_substitution( + name, tag, arguments, expn_state) # }}} -- GitLab