diff --git a/loopy/cse.py b/loopy/cse.py index fb26eec59f0afa36e3fe88e39807856a3f7fab60..3246a34e6267361f1e3d259eb18bbb05521a3fe8 100644 --- a/loopy/cse.py +++ b/loopy/cse.py @@ -285,6 +285,12 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are eliminated. """ + + from loopy.symbolic import SubstitutionCallbackMapper + + c_subst_name = subst_name.replace(".", "_") + subst_name, subst_instance = SubstitutionCallbackMapper.parse_filter(subst_name) + from loopy.kernel import parse_tag default_tag = parse_tag(default_tag) @@ -295,7 +301,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], invocation_descriptors = [] - def gather_substs(expr, name, args, rec): + def gather_substs(expr, name, instance, args, rec): if len(args) != len(subst.arguments): raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" % (subst_name, len(args), len(subst.arguments), )) @@ -310,7 +316,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], return expr from loopy.symbolic import SubstitutionCallbackMapper - scm = SubstitutionCallbackMapper([subst_name], gather_substs) + scm = SubstitutionCallbackMapper([(subst_name, subst_instance)], gather_substs) from loopy.symbolic import ParametrizedSubstitutor rules_except_mine = kernel.substitutions.copy() @@ -324,6 +330,9 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], scm(subst_expander(insn.expression)) + if not invocation_descriptors: + raise RuntimeError("no invocations of '%s' found" % subst_name) + # }}} # {{{ deal with argument names as sweep axes @@ -396,7 +405,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], name = old_name = subst.arguments[saxis] else: old_name = saxis - name = "%s_%s" % (subst_name, old_name) + name = "%s_%s" % (c_subst_name, old_name) if new_storage_axis_names is not None and i < len(new_storage_axis_names): name = new_storage_axis_names[i] @@ -460,7 +469,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], # {{{ set up temp variable - target_var_name = kernel.make_unique_var_name(based_on=subst_name, + target_var_name = kernel.make_unique_var_name(based_on=c_subst_name, extra_used_vars=newly_created_var_names) from loopy.kernel import TemporaryVariable @@ -497,7 +506,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], from loopy.kernel import Instruction compute_insn = Instruction( - id=kernel.make_unique_instruction_id(based_on=subst_name), + id=kernel.make_unique_instruction_id(based_on=c_subst_name), assignee=assignee, expression=compute_expr) @@ -505,7 +514,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], # {{{ substitute rule into expressions in kernel - def do_substs(expr, name, args, rec): + def do_substs(expr, name, instance, args, rec): if len(args) != len(subst.arguments): raise ValueError("invocation of '%s' with too few arguments" % name) @@ -535,7 +544,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], new_insns = [compute_insn] - sub_map = SubstitutionCallbackMapper([subst_name], do_substs) + sub_map = SubstitutionCallbackMapper([(subst_name, subst_instance)], do_substs) for insn in kernel.instructions: new_insn = insn.copy(expression=sub_map(insn.expression)) new_insns.append(new_insn) @@ -543,7 +552,11 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[], new_substs = dict( (s.name, s.copy(expression=sub_map(s.expression))) for s in kernel.substitutions.itervalues() - if s.name != subst_name) + + # leave rule be if instance was specified + # (even if it might end up unused--FIXME) + if subst_instance is not None + or s.name != subst_name) # }}} diff --git a/loopy/symbolic.py b/loopy/symbolic.py index c65f6e7ac202ecebece95b386e89dcc8561f973b..44650bac88f76aae8a5908755c76ace3e3749bbd 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -579,67 +579,93 @@ class IndexVariableFinder(CombineMapper): # }}} -# {{{ parametrized substitutor - -class ParametrizedSubstitutor(IdentityMapper): - def __init__(self, rules): - self.rules = rules +# {{{ substitution callback mapper - def map_variable(self, expr): - if expr.name not in self.rules: - return IdentityMapper.map_variable(self, expr) +class SubstitutionCallbackMapper(IdentityMapper): + @staticmethod + def parse_filter(filt): + if not isinstance(filt, tuple): + dotted_components = filt.split(".") + if len(dotted_components) == 1: + return (dotted_components[0], None) + elif len(dotted_components) == 2: + return tuple(dotted_components) + else: + raise RuntimeError("too many dotted components in '%s'" % filt) + else: + if len(filt) != 2: + raise RuntimeError("substitution name filters " + "may have at most two components") - rule = self.rules[expr.name] - if rule.arguments: - raise RuntimeError("CSE '%s' must be invoked with %d arguments" - % (expr.name, len(rule.arguments))) + return filt - return self.rec(rule.expression) + def __init__(self, names_filter, func): + if names_filter is not None: + new_names_filter = [] + for filt in names_filter: + new_names_filter.append(self.parse_filter(filt)) - def map_call(self, expr): - from pymbolic.primitives import Variable - if (not isinstance(expr.function, Variable) - or expr.function.name not in self.rules): - return IdentityMapper.map_variable(self, expr) - - rule_name = expr.function.name - rule = self.rules[rule_name] - if len(rule.arguments) != len(expr.parameters): - raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" - % (rule_name, len(expr.parameters), len(rule.arguments), )) - - from pymbolic.mapper.substitutor import make_subst_func - subst_map = SubstitutionMapper(make_subst_func( - dict(zip(rule.arguments, expr.parameters)))) + self.names_filter = new_names_filter + else: + self.names_filter = names_filter - return self.rec(subst_map(rule.expression)) + self.func = func -# }}} + def parse_name(self, expr): + from pymbolic.primitives import Variable, Lookup + if isinstance(expr, Variable): + e_name, e_instance = expr.name, None + elif isinstance(expr, Lookup): + if not isinstance(expr.aggregate, Variable): + return None + e_name, e_instance = expr.aggregate.name, expr.name + else: + return None -# {{{ substitution callback mapper + if self.names_filter is not None: + for filt_name, filt_instance in self.names_filter: + if e_name == filt_name: + if filt_instance is None or filt_instance == e_instance: + return e_name, e_instance + else: + return e_name, e_instance -class SubstitutionCallbackMapper(IdentityMapper): - def __init__(self, names, func): - self.names = names - self.func = func + return None def map_variable(self, expr): - if expr.name not in self.names: + parsed_name = self.parse_name(expr) + if parsed_name is None: return IdentityMapper.map_variable(self, expr) - result = self.func(expr, expr.name, (), self.rec) + name, instance = parsed_name + + result = self.func(expr, name, instance, (), self.rec) if result is None: return IdentityMapper.map_variable(self, expr) else: return result + def map_lookup(self, expr): + parsed_name = self.parse_name(expr) + if parsed_name is None: + return IdentityMapper.map_lookup(self, expr) + + name, instance = parsed_name + + result = self.func(expr, name, instance, (), self.rec) + if result is None: + return IdentityMapper.map_lookup(self, expr) + else: + return result + def map_call(self, expr): - from pymbolic.primitives import Variable - if (not isinstance(expr.function, Variable) - or expr.function.name not in self.names): - return IdentityMapper.map_variable(self, expr) + parsed_name = self.parse_name(expr.function) + if parsed_name is None: + return IdentityMapper.map_call(self, expr) + + name, instance = parsed_name - result = self.func(expr, expr.function.name, expr.parameters, self.rec) + result = self.func(expr, name, instance, expr.parameters, self.rec) if result is None: return IdentityMapper.map_call(self, expr) else: @@ -647,6 +673,30 @@ class SubstitutionCallbackMapper(IdentityMapper): # }}} +# {{{ parametrized substitutor + +class ParametrizedSubstitutor(object): + def __init__(self, rules): + self.rules = rules + + def __call__(self, expr): + def expand_if_known(expr, name, instance, args, rec): + rule = self.rules[name] + if len(rule.arguments) != len(args): + raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)" + % (name, len(args), len(rule.arguments), )) + + from pymbolic.mapper.substitutor import make_subst_func + subst_map = SubstitutionMapper(make_subst_func( + dict(zip(rule.arguments, args)))) + + return rec(subst_map(rule.expression)) + + scm = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known) + return scm(expr) + +# }}} + # {{{ wildcard -> unique variable mapper class WildcardToUniqueVariableMapper(IdentityMapper):