From bf38b67c7ce1ca2730814d1e9d17d40b46008414 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 20 Nov 2011 13:42:23 -0500
Subject: [PATCH] First shot at instance-based subst rule precomputation.

---
 loopy/cse.py      |  29 +++++++---
 loopy/symbolic.py | 134 +++++++++++++++++++++++++++++++---------------
 2 files changed, 113 insertions(+), 50 deletions(-)

diff --git a/loopy/cse.py b/loopy/cse.py
index fb26eec59..3246a34e6 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -285,6 +285,12 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
     eliminated.
     """
+
+    from loopy.symbolic import SubstitutionCallbackMapper
+
+    c_subst_name = subst_name.replace(".", "_")
+    subst_name, subst_instance = SubstitutionCallbackMapper.parse_filter(subst_name)
+
     from loopy.kernel import parse_tag
     default_tag = parse_tag(default_tag)
 
@@ -295,7 +301,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     invocation_descriptors = []
 
-    def gather_substs(expr, name, args, rec):
+    def gather_substs(expr, name, instance, args, rec):
         if len(args) != len(subst.arguments):
             raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
                     % (subst_name, len(args), len(subst.arguments), ))
@@ -310,7 +316,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
         return expr
 
     from loopy.symbolic import SubstitutionCallbackMapper
-    scm = SubstitutionCallbackMapper([subst_name], gather_substs)
+    scm = SubstitutionCallbackMapper([(subst_name, subst_instance)], gather_substs)
 
     from loopy.symbolic import ParametrizedSubstitutor
     rules_except_mine = kernel.substitutions.copy()
@@ -324,6 +330,9 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
         scm(subst_expander(insn.expression))
 
+    if not invocation_descriptors:
+        raise RuntimeError("no invocations of '%s' found" % subst_name)
+
     # }}}
 
     # {{{ deal with argument names as sweep axes
@@ -396,7 +405,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
             name = old_name = subst.arguments[saxis]
         else:
             old_name = saxis
-            name = "%s_%s" % (subst_name, old_name)
+            name = "%s_%s" % (c_subst_name, old_name)
 
         if new_storage_axis_names is not None and i < len(new_storage_axis_names):
             name = new_storage_axis_names[i]
@@ -460,7 +469,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     # {{{ set up temp variable
 
-    target_var_name = kernel.make_unique_var_name(based_on=subst_name,
+    target_var_name = kernel.make_unique_var_name(based_on=c_subst_name,
             extra_used_vars=newly_created_var_names)
 
     from loopy.kernel import TemporaryVariable
@@ -497,7 +506,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     from loopy.kernel import Instruction
     compute_insn = Instruction(
-            id=kernel.make_unique_instruction_id(based_on=subst_name),
+            id=kernel.make_unique_instruction_id(based_on=c_subst_name),
             assignee=assignee,
             expression=compute_expr)
 
@@ -505,7 +514,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     # {{{ substitute rule into expressions in kernel
 
-    def do_substs(expr, name, args, rec):
+    def do_substs(expr, name, instance, args, rec):
         if len(args) != len(subst.arguments):
             raise ValueError("invocation of '%s' with too few arguments"
                     % name)
@@ -535,7 +544,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     new_insns = [compute_insn]
 
-    sub_map = SubstitutionCallbackMapper([subst_name], do_substs)
+    sub_map = SubstitutionCallbackMapper([(subst_name, subst_instance)], do_substs)
     for insn in kernel.instructions:
         new_insn = insn.copy(expression=sub_map(insn.expression))
         new_insns.append(new_insn)
@@ -543,7 +552,11 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     new_substs = dict(
             (s.name, s.copy(expression=sub_map(s.expression)))
             for s in kernel.substitutions.itervalues()
-            if s.name != subst_name)
+
+            # leave rule be if instance was specified
+            # (even if it might end up unused--FIXME)
+            if subst_instance is not None 
+            or s.name != subst_name)
 
     # }}}
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index c65f6e7ac..44650bac8 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -579,67 +579,93 @@ class IndexVariableFinder(CombineMapper):
 
 # }}}
 
-# {{{ parametrized substitutor
-
-class ParametrizedSubstitutor(IdentityMapper):
-    def __init__(self, rules):
-        self.rules = rules
+# {{{ substitution callback mapper
 
-    def map_variable(self, expr):
-        if expr.name not in self.rules:
-            return IdentityMapper.map_variable(self, expr)
+class SubstitutionCallbackMapper(IdentityMapper):
+    @staticmethod
+    def parse_filter(filt):
+        if not isinstance(filt, tuple):
+            dotted_components = filt.split(".")
+            if len(dotted_components) == 1:
+                return (dotted_components[0], None)
+            elif len(dotted_components) == 2:
+                return tuple(dotted_components)
+            else:
+                raise RuntimeError("too many dotted components in '%s'" % filt)
+        else:
+            if len(filt) != 2:
+                raise RuntimeError("substitution name filters "
+                        "may have at most two components")
 
-        rule = self.rules[expr.name]
-        if rule.arguments:
-            raise RuntimeError("CSE '%s' must be invoked with %d arguments"
-                    % (expr.name, len(rule.arguments)))
+            return filt
 
-        return self.rec(rule.expression)
+    def __init__(self, names_filter, func):
+        if names_filter is not None:
+            new_names_filter = []
+            for filt in names_filter:
+                new_names_filter.append(self.parse_filter(filt))
 
-    def map_call(self, expr):
-        from pymbolic.primitives import Variable
-        if (not isinstance(expr.function, Variable)
-                or expr.function.name not in self.rules):
-            return IdentityMapper.map_variable(self, expr)
-
-        rule_name = expr.function.name
-        rule = self.rules[rule_name]
-        if len(rule.arguments) != len(expr.parameters):
-            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
-                    % (rule_name, len(expr.parameters), len(rule.arguments), ))
-
-        from pymbolic.mapper.substitutor import make_subst_func
-        subst_map = SubstitutionMapper(make_subst_func(
-            dict(zip(rule.arguments, expr.parameters))))
+            self.names_filter = new_names_filter
+        else:
+            self.names_filter = names_filter
 
-        return self.rec(subst_map(rule.expression))
+        self.func = func
 
-# }}}
+    def parse_name(self, expr):
+        from pymbolic.primitives import Variable, Lookup
+        if isinstance(expr, Variable):
+            e_name, e_instance = expr.name, None
+        elif isinstance(expr, Lookup):
+            if not isinstance(expr.aggregate, Variable):
+                return None
+            e_name, e_instance = expr.aggregate.name, expr.name
+        else:
+            return None
 
-# {{{ substitution callback mapper
+        if self.names_filter is not None:
+            for filt_name, filt_instance in self.names_filter:
+                if e_name == filt_name:
+                    if filt_instance is None or filt_instance == e_instance:
+                        return e_name, e_instance
+        else:
+            return e_name, e_instance
 
-class SubstitutionCallbackMapper(IdentityMapper):
-    def __init__(self, names, func):
-        self.names = names
-        self.func = func
+        return None
 
     def map_variable(self, expr):
-        if expr.name not in self.names:
+        parsed_name = self.parse_name(expr)
+        if parsed_name is None:
             return IdentityMapper.map_variable(self, expr)
 
-        result = self.func(expr, expr.name, (), self.rec)
+        name, instance = parsed_name
+
+        result = self.func(expr, name, instance, (), self.rec)
         if result is None:
             return IdentityMapper.map_variable(self, expr)
         else:
             return result
 
+    def map_lookup(self, expr):
+        parsed_name = self.parse_name(expr)
+        if parsed_name is None:
+            return IdentityMapper.map_lookup(self, expr)
+
+        name, instance = parsed_name
+
+        result = self.func(expr, name, instance, (), self.rec)
+        if result is None:
+            return IdentityMapper.map_lookup(self, expr)
+        else:
+            return result
+
     def map_call(self, expr):
-        from pymbolic.primitives import Variable
-        if (not isinstance(expr.function, Variable)
-                or expr.function.name not in self.names):
-            return IdentityMapper.map_variable(self, expr)
+        parsed_name = self.parse_name(expr.function)
+        if parsed_name is None:
+            return IdentityMapper.map_call(self, expr)
+
+        name, instance = parsed_name
 
-        result = self.func(expr, expr.function.name, expr.parameters, self.rec)
+        result = self.func(expr, name, instance, expr.parameters, self.rec)
         if result is None:
             return IdentityMapper.map_call(self, expr)
         else:
@@ -647,6 +673,30 @@ class SubstitutionCallbackMapper(IdentityMapper):
 
 # }}}
 
+# {{{ parametrized substitutor
+
+class ParametrizedSubstitutor(object):
+    def __init__(self, rules):
+        self.rules = rules
+
+    def __call__(self, expr):
+        def expand_if_known(expr, name, instance, args, rec):
+            rule = self.rules[name]
+            if len(rule.arguments) != len(args):
+                raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
+                        % (name, len(args), len(rule.arguments), ))
+
+            from pymbolic.mapper.substitutor import make_subst_func
+            subst_map = SubstitutionMapper(make_subst_func(
+                dict(zip(rule.arguments, args))))
+
+            return rec(subst_map(rule.expression))
+
+        scm = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known)
+        return scm(expr)
+
+# }}}
+
 # {{{ wildcard -> unique variable mapper
 
 class WildcardToUniqueVariableMapper(IdentityMapper):
-- 
GitLab