From d183034be265017d595a2d20968aef6bd7d9268a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 3 Nov 2011 01:12:23 -0400
Subject: [PATCH] Refactor CSE handling to allow the user to specify the lead
 expression.

---
 loopy/__init__.py   |  15 +++--
 loopy/cse.py        | 142 ++++++++++++++++++++++----------------------
 loopy/kernel.py     |  11 ++--
 loopy/preprocess.py |   7 +--
 loopy/symbolic.py   |  21 ++++++-
 test/test_linalg.py |   2 +-
 test/test_sem.py    |  26 ++++----
 7 files changed, 122 insertions(+), 102 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 4422f18f4..d952538a3 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -412,24 +412,24 @@ def tag_dimensions(kernel, iname_to_tag, force=False):
 
 # {{{ convenience: add_prefetch
 
-def add_prefetch(kernel, var_name, fetch_dims=[], new_inames=None, default_tag="l.auto"):
+def add_prefetch(kernel, var_name, fetch_dims=[], lead_expr=None,
+        new_inames=None, default_tag="l.auto"):
     used_cse_tags = set()
     def map_cse(expr, rec):
         used_cse_tags.add(expr.tag)
         rec(expr.child)
 
-    new_cse_tags = set()
-
     def get_unique_cse_tag():
         from loopy.tools import generate_unique_possibilities
         for cse_tag in generate_unique_possibilities(prefix="fetch_"+var_name):
             if cse_tag not in used_cse_tags:
                 used_cse_tags.add(cse_tag)
-                new_cse_tags.add(cse_tag)
                 return cse_tag
 
+    cse_tag = get_unique_cse_tag()
+
     from loopy.symbolic import VariableFetchCSEMapper
-    vf_cse_mapper = VariableFetchCSEMapper(var_name, get_unique_cse_tag)
+    vf_cse_mapper = VariableFetchCSEMapper(var_name, lambda: cse_tag)
     kernel = kernel.copy(instructions=[
             insn.copy(expression=vf_cse_mapper(insn.expression))
             for insn in kernel.instructions])
@@ -439,9 +439,8 @@ def add_prefetch(kernel, var_name, fetch_dims=[], new_inames=None, default_tag="
     else:
         dtype = kernel.temporary_variables[var_name].dtype
 
-    for cse_tag in new_cse_tags:
-        kernel = realize_cse(kernel, cse_tag, dtype, fetch_dims,
-                new_inames=new_inames, default_tag=default_tag)
+    kernel = realize_cse(kernel, cse_tag, dtype, fetch_dims, lead_expr=lead_expr,
+            new_inames=new_inames, default_tag=default_tag)
 
     return kernel
 
diff --git a/loopy/cse.py b/loopy/cse.py
index 0bcceee64..e76145292 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -11,7 +11,7 @@ from pymbolic import var
 
 
 
-def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse):
+def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse_tag, lead_expr):
     from loopy.kernel import (LocalIndexTagBase, GroupIndexTag, IlpTag)
 
     if isinstance(tag, LocalIndexTagBase):
@@ -32,7 +32,7 @@ def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse):
                     "inherit this iname, which would lead to a write race. "
                     "A likely solution of this problem is to also duplicate this "
                     "iname."
-                    % (cse.prefix, iname, tag))
+                    % (cse_tag, iname, tag))
 
     if iname in duplicate_inames and kind == "g":
         raise RuntimeError("duplicating the iname '%s' into "
@@ -47,7 +47,7 @@ def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse):
     if iname in duplicate_inames:
         raise RuntimeError("duplicating an iname ('%s') "
                 "that the CSE ('%s') does not depend on "
-                "does not make sense" % (iname, cse.child))
+                "does not make sense" % (iname, lead_expr))
 
 
 
@@ -82,6 +82,8 @@ def solve_affine_equations_for_lhs(targets, equations, parameters):
     # occur with a coefficient of 1 on the lhs, and with no other
     # targets on that lhs.
 
+    assert isinstance(targets, (list, tuple)) # had better be ordered
+
     from loopy.symbolic import CoefficientCollector
     coeff_coll = CoefficientCollector()
 
@@ -129,28 +131,30 @@ def solve_affine_equations_for_lhs(targets, equations, parameters):
 
 
 
-def process_cses(kernel, lead_csed, cse_descriptors):
-    from pymbolic.mapper.unifier import BidirectionalUnifier
+def process_cses(kernel, lead_expr, independent_inames, cse_descriptors):
+    if not independent_inames:
+        for csed in cse_descriptors:
+            csed.lead_index_exprs = []
+        return None
 
-    # {{{ parameter set/dependency finding
+    from loopy.symbolic import BidirectionalUnifier
 
-    from loopy.symbolic import DependencyMapper
-    internal_dep_mapper = DependencyMapper(composite_leaves=False)
+    ind_inames_set = set(independent_inames)
 
-    def get_deps(expr):
-        return set(dep.name for dep in internal_dep_mapper(expr))
+    # {{{ parameter set/dependency finding
 
     # Everything that is not one of the duplicate/independent inames
     # is turned into a parameter.
 
-    lead_csed.independent_inames = set(lead_csed.independent_inames)
-    lead_deps = get_deps(lead_csed.cse.child) & kernel.all_inames()
-    params = lead_deps - set(lead_csed.independent_inames)
+    from loopy.symbolic import get_dependencies
+
+    lead_deps = get_dependencies(lead_expr) & kernel.all_inames()
+    params = lead_deps - ind_inames_set
 
     # }}}
 
     lead_domain = to_parameters_or_project_out(params,
-            lead_csed.independent_inames, kernel.domain)
+            ind_inames_set, kernel.domain)
     lead_space = lead_domain.get_space()
 
     footprint = lead_domain
@@ -159,7 +163,7 @@ def process_cses(kernel, lead_csed, cse_descriptors):
     for csed in cse_descriptors:
         # {{{ find dependencies
 
-        cse_deps = get_deps(csed.cse.child) & kernel.all_inames()
+        cse_deps = get_dependencies(csed.cse.child) & kernel.all_inames()
         csed.independent_inames = cse_deps - params
 
         # }}}
@@ -167,12 +171,16 @@ def process_cses(kernel, lead_csed, cse_descriptors):
         # {{{ find unifier
 
         unif = BidirectionalUnifier(
-                lhs_mapping_candidates=lead_csed.independent_inames,
+                lhs_mapping_candidates=ind_inames_set,
                 rhs_mapping_candidates=csed.independent_inames)
-        unifiers = unif(lead_csed.cse.child, csed.cse.child)
+        unifiers = unif(lead_expr, csed.cse.child)
         if not unifiers:
             raise RuntimeError("Unable to unify  "
-            "CSEs '%s' and '%s'" % (lead_csed.cse.child, csed.cse.child))
+            "CSEs '%s' and '%s' (with lhs candidates '%s' and rhs candidates '%s')" % (
+                lead_expr, csed.cse.child,
+                ",".join(unif.lhs_mapping_candidates),
+                ",".join(unif.rhs_mapping_candidates)
+                ))
 
         # }}}
 
@@ -232,7 +240,7 @@ def process_cses(kernel, lead_csed, cse_descriptors):
             if not var_map.is_injective():
                 raise RuntimeError("In CSEs '%s' and '%s': "
                         "cannot find lead indices uniquely"
-                        % (lead_csed.cse.child, csed.cse.child))
+                        % (lead_expr, csed.cse.child))
 
             lead_index_set = restr_rhs_map.domain()
 
@@ -244,7 +252,7 @@ def process_cses(kernel, lead_csed, cse_descriptors):
             if not lead_index_set.is_subset(lead_domain):
                 raise RuntimeError("Index range of CSE '%s' does not cover a "
                         "subset of lead CSE '%s'"
-                        % (csed.cse.child, lead_csed.cse.child))
+                        % (csed.cse.child, lead_expr))
 
             found_good_unifier = True
 
@@ -252,14 +260,14 @@ def process_cses(kernel, lead_csed, cse_descriptors):
 
         if not found_good_unifier:
             raise RuntimeError("No valid unifier for '%s' and '%s'"
-                    % (csed.cse.child, lead_csed.cse.child))
+                    % (csed.cse.child, lead_expr))
 
         uni_recs.append(unifier)
 
         # {{{ solve for lead indices
 
         csed.lead_index_exprs = solve_affine_equations_for_lhs(
-                lead_csed.independent_inames,
+                independent_inames,
                 unifier.equations, params)
 
         # }}}
@@ -270,9 +278,8 @@ def process_cses(kernel, lead_csed, cse_descriptors):
 
 
 
-def make_compute_insn(kernel, lead_csed, target_var_name,
-        independent_inames, new_inames, ind_iname_to_tag):
-    insn = lead_csed.insn
+def make_compute_insn(kernel, cse_tag, lead_expr, target_var_name,
+        independent_inames, new_inames, ind_iname_to_tag, insn):
 
     # {{{ decide whether to force a dep
 
@@ -280,10 +287,11 @@ def make_compute_insn(kernel, lead_csed, target_var_name,
 
     from loopy.symbolic import IndexVariableFinder
     dependencies = IndexVariableFinder(
-            include_reduction_inames=False)(lead_csed.cse.child)
+            include_reduction_inames=False)(lead_expr)
 
     parent_inames = insn.all_inames() | insn.reduction_inames()
-    assert dependencies <= parent_inames
+    #print dependencies, parent_inames
+    #assert dependencies <= parent_inames
 
     for iname in parent_inames:
         if iname in independent_inames:
@@ -292,7 +300,7 @@ def make_compute_insn(kernel, lead_csed, target_var_name,
             tag = kernel.iname_to_tag.get(iname)
 
         check_cse_iname_deps(
-                iname, independent_inames, tag, dependencies, lead_csed.cse)
+                iname, independent_inames, tag, dependencies, cse_tag, lead_expr)
 
     # }}}
 
@@ -309,9 +317,9 @@ def make_compute_insn(kernel, lead_csed, target_var_name,
         dict(
             (old_iname, var(new_iname))
             for old_iname, new_iname in zip(independent_inames, new_inames))))
-    new_inner_expr = subst_map(lead_csed.cse.child)
+    new_inner_expr = subst_map(lead_expr)
 
-    insn_prefix = lead_csed.cse.prefix
+    insn_prefix = cse_tag
     if insn_prefix is None:
         insn_prefix = "cse"
     from loopy.kernel import Instruction
@@ -325,13 +333,16 @@ def make_compute_insn(kernel, lead_csed, target_var_name,
 
 
 def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
-        ind_iname_to_tag={}, new_inames=None, default_tag="l.auto",
-        follow_tag=None):
+        lead_expr=None, ind_iname_to_tag={}, new_inames=None, default_tag="l.auto"):
     """
     :arg independent_inames: which inames are supposed to be separate loops
         in the CSE. Also determines index order of temporary array.
     """
 
+    if isinstance(lead_expr, str):
+        from pymbolic import parse
+        lead_expr = parse(lead_expr)
+
     if not set(independent_inames) <= kernel.all_inames():
         raise ValueError("In CSE realization for '%s': "
                 "cannot make inames '%s' independent--"
@@ -387,14 +398,10 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
 
     # {{{ gather cse descriptors
 
-    eligible_tags = [cse_tag]
-    if follow_tag is not None:
-        eligible_tags.append(follow_tag)
-
     cse_descriptors = []
 
     def gather_cses(cse, rec):
-        if cse.prefix not in eligible_tags:
+        if cse.prefix != cse_tag:
             rec(cse.child)
             return
 
@@ -415,25 +422,23 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
     if not cse_descriptors:
         raise RuntimeError("no CSEs tagged '%s' found" % cse_tag)
 
-    lead_cse_indices = [i for i, csed in enumerate(cse_descriptors) 
-            if csed.cse.prefix == cse_tag]
-    if follow_tag is not None:
-        if len(lead_cse_indices) != 1:
-            raise RuntimeError("%d lead CSEs (should be exactly 1) found for tag '%s'"
-                    % (len(lead_cse_indices), cse_tag))
-
-        lead_idx, = lead_cse_indices
-    else:
-        # pick a lead CSE at random
-        lead_idx = 0
+    if lead_expr is None:
+        from loopy.symbolic import get_dependencies
+        for csed in cse_descriptors:
+            if set(independent_inames) <= get_dependencies(csed.cse.child):
+                # pick the first cse that has the required inames as the lead expression
+                lead_expr = csed.cse.child
+                break
 
-    lead_csed = cse_descriptors.pop(lead_idx)
-    lead_csed.independent_inames = independent_inames
+        if lead_expr is None:
+            raise RuntimeError("could not find a suitable 'lead' CSE that depends on "
+                    "inames '%s'" % ",".join(independent_inames))
 
     # }}}
 
     # FIXME: Do something with the footprint
-    footprint = process_cses(kernel, lead_csed, cse_descriptors)
+    # (CAUTION: Can be None if no independent_inames)
+    footprint = process_cses(kernel, lead_expr, independent_inames, cse_descriptors)
 
     # {{{ set up temp variable
 
@@ -460,29 +465,26 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
     # }}}
 
     compute_insn = make_compute_insn(
-            kernel, lead_csed, target_var_name,
-            independent_inames, new_inames, ind_iname_to_tag)
+            kernel, cse_tag, lead_expr, target_var_name,
+            independent_inames, new_inames, ind_iname_to_tag,
+            # arbitrarily use the first CSE's insn for the dep check
+            cse_descriptors[0].insn)
 
     # {{{ substitute variable references into instructions
 
     def subst_cses(cse, rec):
-        if cse is lead_csed.cse:
-            csed = lead_csed
-
-            lead_indices = [var(iname) for iname in independent_inames]
-        else:
-            found = False
-            for csed in cse_descriptors:
-                if cse is csed.cse:
-                    found = True
-                    break
-
-            if not found:
-                from pymbolic.primitives import CommonSubexpression
-                return CommonSubexpression(
-                        rec(cse.child), cse.prefix)
-
-            lead_indices = csed.lead_index_exprs
+        found = False
+        for csed in cse_descriptors:
+            if cse is csed.cse:
+                found = True
+                break
+
+        if not found:
+            from pymbolic.primitives import CommonSubexpression
+            return CommonSubexpression(
+                    rec(cse.child), cse.prefix)
+
+        lead_indices = csed.lead_index_exprs
 
         new_outer_expr = var(target_var_name)
         if lead_indices:
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 1fbea5493..ee1bcba47 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -289,7 +289,7 @@ class Instruction(Record):
         elif self.boostable == False:
             result += " (not boostable)"
         elif self.boostable is None:
-            result += " (boostability unknown)"
+            pass
         else:
             raise RuntimeError("unexpected value for Instruction.boostable")
 
@@ -329,9 +329,8 @@ class Instruction(Record):
 
     @memoize_method
     def get_read_var_names(self):
-        from loopy.symbolic import DependencyMapper
-        return set(var.name for var in
-                DependencyMapper(composite_leaves=False)(self.expression))
+        from loopy.symbolic import get_dependencies
+        return get_dependencies(self.expression)
 
 # }}}
 
@@ -774,9 +773,9 @@ class LoopKernel(Record):
         for insn in self.instructions:
             all_inames_by_insns |= insn.all_inames()
 
-        if all_inames_by_insns != self.all_inames():
+        if not all_inames_by_insns <= self.all_inames():
             raise RuntimeError("inames collected from instructions (%s) "
-                    "do not match domain inames (%s)"
+                    "that are not present in domain (%s)"
                     % (", ".join(sorted(all_inames_by_insns)), 
                         ", ".join(sorted(self.all_inames()))))
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 078bb71a4..eede44a43 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -14,10 +14,7 @@ def mark_local_temporaries(kernel):
 
     writers = find_accessors(kernel, readers=False)
 
-    from loopy.symbolic import DependencyMapper
-    dm = DependencyMapper(composite_leaves=False)
-    def get_deps(expr):
-        return set(var.name for var in dm(expr))
+    from loopy.symbolic import get_dependencies
 
     for temp_var in kernel.temporary_variables.itervalues():
         my_writers = writers[temp_var.name]
@@ -27,7 +24,7 @@ def mark_local_temporaries(kernel):
             insn = kernel.id_to_insn[insn_id]
             has_local_parallel_write = has_local_parallel_write or any(
                     isinstance(kernel.iname_to_tag.get(iname), LocalIndexTagBase)
-                    for iname in get_deps(insn.get_assignee_indices())
+                    for iname in get_dependencies(insn.get_assignee_indices())
                     & kernel.all_inames())
 
         new_temp_vars[temp_var.name] = temp_var.copy(
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 210917484..6ecf02360 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -15,6 +15,9 @@ from pymbolic.mapper.stringifier import \
         StringifyMapper as StringifyMapperBase
 from pymbolic.mapper.dependency import \
         DependencyMapper as DependencyMapperBase
+from pymbolic.mapper.unifier import BidirectionalUnifier \
+        as BidirectionalUnifierBase
+
 import numpy as np
 import islpy as isl
 from islpy import dim_type
@@ -79,7 +82,17 @@ class StringifyMapper(StringifyMapperBase):
 
 class DependencyMapper(DependencyMapperBase):
     def map_reduction(self, expr):
-        return set(expr.inames) | self.rec(expr.expr)
+        return self.rec(expr.expr)
+
+class BidirectionalUnifier(BidirectionalUnifierBase):
+    def map_reduction(self, expr, other, unis):
+        if not isinstance(other, type(expr)):
+            return self.treat_mismatch(expr, other, unis)
+        if (expr.inames != other.inames
+                or type(expr.operation) != type(other.operation)):
+            return []
+        # same inames and operation type: recurse into the reduced expressions
+        return self.rec(expr.expr, other.expr, unis)
 
 # }}}
 
@@ -632,6 +645,12 @@ class PrimeAdder(IdentityMapper):
 
 # }}}
 
+def get_dependencies(expr):
+    """Return the set of names of the dependencies found in *expr*."""
+    dep_mapper = DependencyMapper(composite_leaves=False)
+
+    return set(dep.name for dep in dep_mapper(expr))
+
 
 
 # vim: foldmethod=marker
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 7ae2a1d4d..40641e9c2 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -635,7 +635,7 @@ def test_image_matrix_mul_ilp(ctx_factory):
     # conflict-free
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["j_inner_outer", "j_inner_inner", "k_inner"],
-            ["b_j_io", "b_j_ii", "b_k_i"])
+            new_inames=["b_j_io", "b_j_ii", "b_k_i"])
     knl = lp.join_dimensions(knl, ["b_j_io", "b_j_ii"])
 
     kernel_gen = lp.generate_loop_schedules(knl)
diff --git a/test/test_sem.py b/test/test_sem.py
index 38426cdb0..6b3ac68b8 100644
--- a/test/test_sem.py
+++ b/test/test_sem.py
@@ -273,8 +273,6 @@ def test_sem_3d(ctx_factory):
     dtype = np.float32
     ctx = ctx_factory()
     order = "C"
-    queue = cl.CommandQueue(ctx,
-            properties=cl.command_queue_properties.PROFILING_ENABLE)
 
     n = 8
 
@@ -285,11 +283,11 @@ def test_sem_3d(ctx_factory):
 
     # K - run-time symbolic
     knl = lp.make_kernel(ctx.devices[0],
-            "[K] -> {[i,j,k,e,m]: 0<=i,j,k,m<%d and 0<=e<K}" % n,
+            "[K] -> {[i,j,k,e,m,o,gi]: 0<=i,j,k,m,o<%d and 0<=e<K and 0<=gi<6}" % n,
             [
-                "CSE: ur(i,j,k) = sum_float32(@m, D[i,m]*u[e,m,j,k])",
-                "CSE: us(i,j,k) = sum_float32(@m, D[j,m]*u[e,i,m,k])",
-                "CSE: ut(i,j,k) = sum_float32(@m, D[k,m]*u[e,i,j,m])",
+                "CSE: ur(i,j,k) = sum_float32(@o, D[i,o]*u[e,o,j,k])",
+                "CSE: us(i,j,k) = sum_float32(@o, D[j,o]*u[e,i,o,k])",
+                "CSE: ut(i,j,k) = sum_float32(@o, D[k,o]*u[e,i,j,o])",
 
                 "lap[e,i,j,k]  = "
                 "  sum_float32(m, D[m,i]*(G[0,e,m,j,k]*ur(m,j,k) + G[1,e,m,j,k]*us(m,j,k) + G[2,e,m,j,k]*ut(m,j,k)))"
@@ -305,16 +303,22 @@ def test_sem_3d(ctx_factory):
             ],
             name="semlap", assumptions="K>=1")
 
-    #knl = lp.realize_cse(knl, "D", np.float32, ["i_dr", "m_dr"])
-    #knl = lp.realize_cse(knl, "D", np.float32, ["i_dr", "m_dr"])
-    #knl = lp.realize_cse(knl, "u", np.float32, ["m_dr", "j_dr", "k_dr"])
-    #knl = lp.add_prefetch(knl, "G", ["m", "j", "k"])
+
+    knl = lp.add_prefetch(knl, "G", ["gi", "m", "j", "k"], "G[gi,e,m,j,k]")
+    knl = lp.add_prefetch(knl, "D", ["m", "j"])
+    knl = lp.add_prefetch(knl, "u", ["i", "j", "k"], "u[e,i,j,k]")
+    knl = lp.realize_cse(knl, "ur", np.float32, ["k", "j", "m"])
+    knl = lp.realize_cse(knl, "us", np.float32, ["i", "m", "k"])
+    knl = lp.realize_cse(knl, "ut", np.float32, ["i", "j", "m"])
 
     seq_knl = knl
+    print seq_knl
+    #print lp.preprocess_kernel(seq_knl)
+    1/0
 
     knl = lp.split_dimension(knl, "e", 16, outer_tag="g.0")#, slabs=(0, 1))
     #knl = lp.split_dimension(knl, "e_inner", 4, inner_tag="ilp")
-    print knl
+
     knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1"))
 
     kernel_gen = lp.generate_loop_schedules(knl,
-- 
GitLab