From 5485e5758384496bfa8870cf0338748b17dd5bd5 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 7 Nov 2011 11:46:19 -0500
Subject: [PATCH] Rewrite CSEs in terms of unification templates.

---
 MEMO                      |   8 +-
 loopy/__init__.py         |   6 +-
 loopy/cse.py              | 553 +++++++++++++++++++++++++-------------
 loopy/kernel.py           |   7 +-
 loopy/symbolic.py         |  23 +-
 test/test_fem_assembly.py |   2 +
 test/test_linalg.py       |  23 +-
 test/test_loopy.py        |  64 ++++-
 test/test_sem.py          |   2 +-
 9 files changed, 471 insertions(+), 217 deletions(-)

diff --git a/MEMO b/MEMO
index 1866105af..fc33a978c 100644
--- a/MEMO
+++ b/MEMO
@@ -43,12 +43,6 @@ To-do
 
 - Pick not just axis 0, but all axes by lowest available stride
 
-- Unidirectional unification
-
-- Unification Wildcards
-
-- No looking at the lead domain?
-
 - Fix all tests
 
 - Deal with equality constraints.
@@ -57,6 +51,8 @@ To-do
 Future ideas
 ^^^^^^^^^^^^
 
+- Barriers for data exchanged via global vars?
+
 - Float4 joining on fetch/store?
 
 - How can one automatically generate something like microblocks?
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 379f1c3b8..9e6f16fa6 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -157,7 +157,7 @@ def make_kernel(*args, **kwargs):
             from pymbolic.primitives import Variable
             for index_expr in insn.get_assignee_indices():
                 if (not isinstance(index_expr, Variable)
-                        or not index_expr.name in knl.insn_inames(insn)):
+                        or not index_expr.name in knl.all_inames()):
                     raise RuntimeError(
                             "only plain inames are allowed in "
                             "the lvalue index when declaring the "
@@ -419,7 +419,7 @@ def tag_dimensions(kernel, iname_to_tag, force=False):
 
 # {{{ convenience: add_prefetch
 
-def add_prefetch(kernel, var_name, fetch_dims=[], lead_expr=None,
+def add_prefetch(kernel, var_name, fetch_dims=[], uni_template=None,
         new_inames=None, default_tag="l.auto"):
     used_cse_tags = set()
     def map_cse(expr, rec):
@@ -446,7 +446,7 @@ def add_prefetch(kernel, var_name, fetch_dims=[], lead_expr=None,
     else:
         dtype = kernel.temporary_variables[var_name].dtype
 
-    kernel = realize_cse(kernel, cse_tag, dtype, fetch_dims, lead_expr=lead_expr,
+    kernel = realize_cse(kernel, cse_tag, dtype, fetch_dims, uni_template=uni_template,
             new_inames=new_inames, default_tag=default_tag)
 
     return kernel
diff --git a/loopy/cse.py b/loopy/cse.py
index 2fe4340ab..d9fa5656a 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -2,7 +2,8 @@ from __future__ import division
 
 import islpy as isl
 from islpy import dim_type
-from loopy.kernel import AutoFitLocalIndexTag
+from loopy.symbolic import get_dependencies, SubstitutionMapper
+from pymbolic.mapper.substitutor import make_subst_func
 import numpy as np
 
 from pytools import Record
@@ -11,7 +12,7 @@ from pymbolic import var
 
 
 
-def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse_tag, lead_expr):
+def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse_tag, uni_template):
     from loopy.kernel import (LocalIndexTagBase, GroupIndexTag, IlpTag)
 
     if isinstance(tag, LocalIndexTagBase):
@@ -47,14 +48,14 @@ def check_cse_iname_deps(iname, duplicate_inames, tag, dependencies, cse_tag, le
     if iname in duplicate_inames:
         raise RuntimeError("duplicating an iname ('%s') "
                 "that the CSE ('%s') does not depend on "
-                "does not make sense" % (iname, lead_expr))
+                "does not make sense" % (iname, uni_template))
 
 
 
 
 class CSEDescriptor(Record):
     __slots__ = ["insn", "cse", "independent_inames",
-            "lead_index_exprs"]
+            "unif_var_dict"]
 
 
 
@@ -77,109 +78,168 @@ def to_parameters_or_project_out(param_inames, set_inames, set):
 
 
 
-def solve_affine_equations_for_lhs(targets, equations, parameters):
-    # Not a very good solver: The desired variable must already
-    # occur with a coefficient of 1 on the lhs, and with no other
-    # targets on that lhs.
+def gaussian_elimination(mat, rhs):
+    """Row-reduce *mat* with lcm-based integer row operations so that every
+    pivot column has a single nonzero entry; *rhs* gets the same operations.
+
+    Both arrays are modified in place and returned as ``(mat, rhs)``.
+    """
 
-    assert isinstance(targets, (list, tuple)) # had better be ordered
+    from pymbolic.algorithm import lcm, gcd_many
+
+    n_rows, n_cols = mat.shape
+    row = 0
+
+    for col in range(n_cols):
+        if row >= n_rows:
+            break
+        # find the first row at or below `row` with a nonzero in this column
+        pivot_row = None
+        for k in range(row, n_rows):
+            if mat[k, col]:
+                pivot_row = k
+                break
+
+        if pivot_row is None:
+            continue
+
+        # swap rows `row` and `pivot_row`
+        mat[row], mat[pivot_row] = \
+                (mat[pivot_row].copy(), mat[row].copy())
+        rhs[row], rhs[pivot_row] = \
+                (rhs[pivot_row].copy(), rhs[row].copy())
+
+        # cancel this column's entry in every other row
+        for u in range(n_rows):
+            if u == row or not mat[u, col]:
+                continue
+
+            common = lcm(mat[u, col], mat[row, col])
+            u_fac = common//mat[u, col]
+            piv_fac = common//mat[row, col]
+
+            mat[u] = u_fac*mat[u] - piv_fac*mat[row]
+            rhs[u] = u_fac*rhs[u] - piv_fac*rhs[row]
+            assert mat[u, col] == 0
+
+        row += 1
+
+    # divide each row by the gcd of its nonzero entries
+    for r in range(n_rows):
+        g = gcd_many(*(
+            [a for a in mat[r] if a] + [a for a in rhs[r] if a]))
+
+        mat[r] //= g
+        rhs[r] //= g
+
+    return mat, rhs
+
+
+
+
+def solve_affine_equations_for(targets, equations):
+    """Solve the affine *equations*, a sequence of (lhs, rhs) expression
+    pairs, for the variable names in the set *targets*.
+
+    Return a dict mapping each target name to its value in terms of the
+    remaining variables. Raise :exc:`RuntimeError` if a target cannot be
+    solved for uniquely or its pivot coefficient is not +/-1.
+    """
+    # fix an order for targets
+    targets_list = list(targets)
+    target_idx_lut = dict((tgt_name, idx)
+            for idx, tgt_name in enumerate(targets_list))
+
+    # Find non-target variables and fix an order (constant column goes last).
+    nontargets = set()
+    for lhs, rhs in equations:
+        nontargets.update(get_dependencies(lhs) - targets)
+        nontargets.update(get_dependencies(rhs) - targets)
+    nontargets_list = list(nontargets)
+    nontarget_idx_lut = dict((var_name, idx)
+            for idx, var_name in enumerate(nontargets_list))
 
     from loopy.symbolic import CoefficientCollector
     coeff_coll = CoefficientCollector()
 
-    target_values = {}
+    # {{{ build matrix and rhs (accumulate: a name may occur on both sides)
 
-    for lhs, rhs in equations:
-        lhs_coeffs = coeff_coll(lhs)
-        rhs_coeffs = coeff_coll(rhs)
+    mat = np.zeros((len(equations), len(targets)), dtype=object)
+    rhs_mat = np.zeros((len(equations), len(nontargets)+1), dtype=object)
 
-        def shift_to_rhs(key):
-            rhs_coeffs[key] = rhs_coeffs.get(key, 0) - lhs_coeffs[key]
-            del lhs_coeffs[key]
+    for i_eqn, (lhs, rhs) in enumerate(equations):
+        for lhs_factor, coeffs in [(1, coeff_coll(lhs)), (-1, coeff_coll(rhs))]:
+            for key, coeff in coeffs.iteritems():
+                if key in targets:
+                    mat[i_eqn, target_idx_lut[key]] += lhs_factor*coeff
+                elif key in nontargets:
+                    rhs_mat[i_eqn, nontarget_idx_lut[key]] -= lhs_factor*coeff
+                elif key == 1:
+                    rhs_mat[i_eqn, -1] -= lhs_factor*coeff
+                else:
+                    raise ValueError("key '%s' not understood" % key)
 
-        for key in list(lhs_coeffs.iterkeys()):
-            if key in targets:
-                continue
-            elif key in parameters or key == 1:
-                shift_to_rhs(key)
-            else:
-                raise RuntimeError("unexpected key")
-
-        if len(lhs_coeffs) > 1:
-            raise RuntimeError("comically unable to solve '%s = %s' "
-                    "for one of the target variables '%s'"
-                    % (lhs, rhs, ",".join(targets)))
-
-        (tgt_name, coeff), = lhs_coeffs.iteritems()
-        if coeff != 1:
-            raise RuntimeError("comically unable to solve '%s = %s' "
-                    "for one of the target variables '%s'"
-                    % (lhs, rhs, ",".join(targets)))
-
-        solution = 0
-        for key, coeff in rhs_coeffs.iteritems():
-            if key == 1:
-                solution += coeff
-            else:
-                solution += coeff*var(key)
+    # }}}
 
-        assert tgt_name not in target_values
-        target_values[tgt_name] = solution
 
-    return [target_values[tname] for tname in targets]
+    mat, rhs_mat = gaussian_elimination(mat, rhs_mat)
 
+    # No need to check for overdetermined system: The case where the
+    # map is empty is treated sufficiently by isl.
 
+    result = {}
+    for j, target in enumerate(targets_list):
+        nonz_rows, = np.where(mat[:, j])  # np.where yields a 1-tuple here
+        if len(nonz_rows) != 1 or len(np.where(mat[nonz_rows[0]])[0]) != 1:
+            raise RuntimeError("cannot uniquely solve for '%s'" % target)
 
+        nonz_row, = nonz_rows
 
-def process_cses(kernel, lead_expr, independent_inames, cse_descriptors):
-    if not independent_inames:
-        for csed in cse_descriptors:
-            csed.lead_index_exprs = []
-        return None
+        if abs(mat[nonz_row, j]) != 1:
+            raise RuntimeError("division with remainder in linear solve for '%s'"
+                    % target)
+        div = mat[nonz_row, j]
 
-    from loopy.symbolic import BidirectionalUnifier
+        target_val = int(rhs_mat[nonz_row, -1]) // div
+        for nontarget, coeff in zip(nontargets_list, rhs_mat[nonz_row]):
+            target_val += (int(coeff) // div) * var(nontarget)
 
-    ind_inames_set = set(independent_inames)
+        result[target] = target_val
 
-    # {{{ parameter set/dependency finding
 
-    # Everything that is not one of the duplicate/independent inames
-    # is turned into a parameter.
+    return result
 
-    from loopy.symbolic import get_dependencies
 
-    lead_deps = get_dependencies(lead_expr) & kernel.all_inames()
-    params = lead_deps - ind_inames_set
 
-    # }}}
 
-    lead_domain = to_parameters_or_project_out(params,
-            ind_inames_set, kernel.domain)
-    lead_space = lead_domain.get_space()
+def process_cses(kernel, uni_template,
+        independent_inames, matching_vars, cse_descriptors):
+    if not independent_inames:
+        for csed in cse_descriptors:
+            csed.lead_index_exprs = []
+        return None
 
-    footprint = lead_domain
+    from loopy.symbolic import UnidirectionalUnifier
 
-    uni_recs = []
-    for csed in cse_descriptors:
-        # {{{ find dependencies
+    ind_inames_set = set(independent_inames)
 
-        cse_deps = get_dependencies(csed.cse.child) & kernel.all_inames()
-        csed.independent_inames = cse_deps - params
+    uni_iname_list = independent_inames + matching_vars
+    footprint = None
 
-        # }}}
+    uni_recs = []
+    matching_var_values = {}
 
+    for csed in cse_descriptors:
         # {{{ find unifier
 
-        unif = BidirectionalUnifier(
-                lhs_mapping_candidates=ind_inames_set,
-                rhs_mapping_candidates=csed.independent_inames)
-        unifiers = unif(lead_expr, csed.cse.child)
+        unif = UnidirectionalUnifier(
+                lhs_mapping_candidates=ind_inames_set | set(matching_vars))
+        unifiers = unif(uni_template, csed.cse.child)
         if not unifiers:
             raise RuntimeError("Unable to unify  "
-            "CSEs '%s' and '%s' (with lhs candidates '%s' and rhs candidates '%s')" % (
-                lead_expr, csed.cse.child,
+            "CSEs '%s' and '%s' (with lhs candidates '%s')" % (
+                uni_template, csed.cse.child,
                 ",".join(unif.lhs_mapping_candidates),
-                ",".join(unif.rhs_mapping_candidates)
                 ))
 
         # }}}
@@ -189,18 +249,13 @@ def process_cses(kernel, lead_expr, independent_inames, cse_descriptors):
         for unifier in unifiers:
             # {{{ construct, check mapping
 
-            rhs_domain = to_parameters_or_project_out(
-                    params, csed.independent_inames, kernel.domain)
-            rhs_space = rhs_domain.get_space()
+            map_space = kernel.space
+            ln = len(uni_iname_list)
+            rn = kernel.space.dim(dim_type.out)
 
-            map_space = lead_space
-            ln = lead_space.dim(dim_type.set)
-            map_space = map_space.move_dims(dim_type.in_, 0, dim_type.set, 0, ln)
-            rn = rhs_space.dim(dim_type.set)
-            map_space = map_space.add_dims(dim_type.out, rn)
-            for i in range(rhs_domain.dim(dim_type.set)):
-                map_space = map_space.set_dim_name(dim_type.out, i,
-                        rhs_domain.get_dim_name(dim_type.set, i)+"'")
+            map_space = map_space.add_dims(dim_type.in_, ln)
+            for i, iname in enumerate(uni_iname_list):
+                map_space = map_space.set_dim_name(dim_type.in_, i, iname)
 
             set_space = map_space.move_dims(
                     dim_type.out, rn,
@@ -208,11 +263,10 @@ def process_cses(kernel, lead_expr, independent_inames, cse_descriptors):
 
             var_map = None
 
-            from loopy.symbolic import aff_from_expr, PrimeAdder
-            add_primes = PrimeAdder(csed.independent_inames)
+            from loopy.symbolic import aff_from_expr
             for lhs, rhs in unifier.equations:
                 cns = isl.Constraint.equality_from_aff(
-                        aff_from_expr(set_space, lhs - add_primes(rhs)))
+                        aff_from_expr(set_space, lhs - rhs))
 
                 cns_map = isl.BasicMap.from_constraint(cns)
                 if var_map is None:
@@ -226,33 +280,27 @@ def process_cses(kernel, lead_expr, independent_inames, cse_descriptors):
 
             restr_rhs_map = (
                     isl.Map.from_basic_map(var_map)
-                    .intersect_range(rhs_domain))
+                    .intersect_range(kernel.domain))
 
             # Sanity check: If the range of the map does not recover the
             # domain of the expression, the unifier must have been no
             # good.
-            if restr_rhs_map.range() != rhs_domain:
+            if restr_rhs_map.range() != kernel.domain:
                 continue
 
             # Sanity check: Injectivity here means that unique lead indices
-            # can be found for each 
+            # can be found for each
 
             if not var_map.is_injective():
                 raise RuntimeError("In CSEs '%s' and '%s': "
                         "cannot find lead indices uniquely"
-                        % (lead_expr, csed.cse.child))
-
-            lead_index_set = restr_rhs_map.domain()
-
-            footprint = footprint.union(lead_index_set)
-
-            # FIXME: This restriction will be lifted in the future, and the
-            # footprint will instead be used as the lead domain.
+                        % (uni_template, csed.cse.child))
 
-            if not lead_index_set.is_subset(lead_domain):
-                raise RuntimeError("Index range of CSE '%s' does not cover a "
-                        "subset of lead CSE '%s'"
-                        % (csed.cse.child, lead_expr))
+            footprint_contrib = restr_rhs_map.domain()
+            if footprint is None:
+                footprint = footprint_contrib
+            else:
+                footprint = footprint.union(footprint_contrib)
 
             found_good_unifier = True
 
@@ -260,32 +308,53 @@ def process_cses(kernel, lead_expr, independent_inames, cse_descriptors):
 
         if not found_good_unifier:
             raise RuntimeError("No valid unifier for '%s' and '%s'"
-                    % (csed.cse.child, lead_expr))
+                    % (csed.cse.child, uni_template))
 
         uni_recs.append(unifier)
 
-        # {{{ solve for lead indices
+        # {{{ check that matching_vars have a unique_value
 
-        csed.lead_index_exprs = solve_affine_equations_for_lhs(
-                independent_inames,
-                unifier.equations, params)
+        csed.unif_var_dict = dict((lhs.name, rhs)
+                for lhs, rhs in unifier.equations)
+        for mv_name in matching_vars:
+            if mv_name in matching_var_values:
+                if matching_var_values[mv_name] != csed.unif_var_dict[mv_name]:
+                    raise RuntimeError("two different expressions encountered "
+                            "for matching variable: '%s' and '%s'" % (
+                                matching_var_values[mv_name], csed.unif_var_dict[mv_name]))
+            else:
+                matching_var_values[mv_name] = csed.unif_var_dict[mv_name]
 
         # }}}
 
-    return footprint.coalesce()
+        if 0:
+            # {{{ solve for lead indices
+
+            solve_targets = set()
+            for lhs, rhs in unifier.equations:
+                solve_targets.update(get_dependencies(rhs)
+                        - kernel.non_iname_variable_names())
+
+            csed.from_unified_exprs = solve_affine_equations_for(
+                    solve_targets, unifier.equations)
+
+            # }}}
 
+    return footprint, matching_var_values,
 
 
 
 
-def make_compute_insn(kernel, cse_tag, lead_expr, target_var_name,
-        independent_inames, new_inames, ind_iname_to_tag, insn):
+
+def make_compute_insn(kernel, cse_tag, uni_template,
+        target_var_name, target_var_base_indices,
+        independent_inames, ind_iname_to_tag, insn):
 
     # {{{ decide whether to force a dep
 
     from loopy.symbolic import IndexVariableFinder
     dependencies = IndexVariableFinder(
-            include_reduction_inames=False)(lead_expr)
+            include_reduction_inames=False)(uni_template)
 
     parent_inames = kernel.insn_inames(insn) | insn.reduction_inames()
     #print dependencies, parent_inames
@@ -298,25 +367,18 @@ def make_compute_insn(kernel, cse_tag, lead_expr, target_var_name,
             tag = kernel.iname_to_tag.get(iname)
 
         check_cse_iname_deps(
-                iname, independent_inames, tag, dependencies, cse_tag, lead_expr)
+                iname, independent_inames, tag, dependencies, cse_tag, uni_template)
 
     # }}}
 
     assignee = var(target_var_name)
 
-    if new_inames:
+    if independent_inames:
         assignee = assignee[tuple(
-            var(iname) for iname in new_inames
+            var(iname)-bi
+            for iname, bi in zip(independent_inames, target_var_base_indices)
             )]
 
-    from loopy.symbolic import SubstitutionMapper
-    from pymbolic.mapper.substitutor import make_subst_func
-    subst_map = SubstitutionMapper(make_subst_func(
-        dict(
-            (old_iname, var(new_iname))
-            for old_iname, new_iname in zip(independent_inames, new_inames))))
-    new_inner_expr = subst_map(lead_expr)
-
     insn_prefix = cse_tag
     if insn_prefix is None:
         insn_prefix = "cse"
@@ -324,31 +386,77 @@ def make_compute_insn(kernel, cse_tag, lead_expr, target_var_name,
     return Instruction(
             id=kernel.make_unique_instruction_id(based_on=insn_prefix+"_compute"),
             assignee=assignee,
-            expression=new_inner_expr)
+            expression=uni_template)
 
 
 
 
 def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
-        lead_expr=None, ind_iname_to_tag={}, new_inames=None, default_tag="l.auto"):
+        uni_template=None, ind_iname_to_tag={}, new_inames=None, default_tag="l.auto"):
     """
     :arg independent_inames: which inames are supposed to be separate loops
         in the CSE. Also determines index order of temporary array.
+        The variables in independent_inames refer to the unification
+        template.
+    :arg uni_template: An expression against which all targeted subexpressions
+        must unify
+
+        If None, a unification template will be chosen from among the targeted
+        CSEs. That CSE is chosen to depend on all the variables in
+        *independent_inames*.  It is an error if no such expression can be
+        found.
+
+        May contain '*' wildcards that will have to match exactly across all
+        unifications.
+
+    Process:
+
+    - Find all targeted CSEs.
+
+    - Find *uni_template* as described above.
+
+    - Turn all wildcards in *uni_template* into matching-relevant (but not
+      independent, in the sense of *independent_inames*) variables.
+
+    - Unify the CSEs with the unification template, detecting mappings
+      of template variables to variables used in the CSE.
+
+    - Find the (union) footprint of the CSEs in terms of the
+      *independent_inames*.
+
+    - Augment the kernel domain by that footprint and generate the fetch
+      instruction.
+
+    - Replace the CSEs according to the mapping detected in unification.
     """
 
-    if isinstance(lead_expr, str):
+    newly_created_var_names = set()
+
+    # {{{ replace any wildcards in uni_template with new variables
+
+    if isinstance(uni_template, str):
         from pymbolic import parse
-        lead_expr = parse(lead_expr)
+        uni_template = parse(uni_template)
+
+    def get_unique_var_name():
+        if cse_tag is None:
+            based_on = "cse_wc"
+        else:
+            based_on = cse_tag+"_wc"
+
+        result = kernel.make_unique_var_name(
+                based_on=based_on, extra_used_vars=newly_created_var_names)
+        newly_created_var_names.add(result)
+        return result
 
-    if not set(independent_inames) <= kernel.all_inames():
-        raise ValueError("In CSE realization for '%s': "
-                "cannot make iname(s) '%s' independent--"
-                "it/they don't already exist" % (
-                    cse_tag,
-                    ",".join(
-                        set(independent_inames)-kernel.all_inames())))
+    if uni_template is not None:
+        from loopy.symbolic import WildcardToUniqueVariableMapper
+        wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
+        uni_template = wc_map(uni_template)
 
-    # {{{ process parallel_inames and ind_iname_to_tag arguments
+    # }}}
+
+    # {{{ process ind_iname_to_tag argument
 
     ind_iname_to_tag = ind_iname_to_tag.copy()
 
@@ -364,35 +472,6 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
 
     # }}}
 
-    # {{{ process new_inames argument, think of new inames for inames to be duplicated
-
-    if new_inames is None:
-        new_inames = [None] * len(independent_inames)
-
-    if len(new_inames) != len(independent_inames):
-        raise ValueError("If given, the new_inames argument must have the "
-                "same length as independent_inames")
-
-    temp_new_inames = []
-    for old_iname, new_iname in zip(independent_inames, new_inames):
-        if new_iname is None:
-            if cse_tag is not None:
-                based_on = old_iname+"_"+cse_tag
-            else:
-                based_on = old_iname
-
-            new_iname = kernel.make_unique_var_name(based_on, set(temp_new_inames))
-            assert new_iname != old_iname
-
-        temp_new_inames.append(new_iname)
-
-    new_inames = temp_new_inames
-
-    # }}}
-
-    from loopy.isl_helpers import duplicate_axes
-    new_domain = duplicate_axes(kernel.domain, independent_inames, new_inames)
-
     # {{{ gather cse descriptors
 
     cse_descriptors = []
@@ -414,28 +493,132 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
 
     # }}}
 
-    # {{{ find/pick the lead cse
+    # {{{ find/pick a unification template
 
     if not cse_descriptors:
         raise RuntimeError("no CSEs tagged '%s' found" % cse_tag)
 
-    if lead_expr is None:
-        from loopy.symbolic import get_dependencies
+    if uni_template is None:
         for csed in cse_descriptors:
             if set(independent_inames) <= get_dependencies(csed.cse.child):
-                # pick the first cse that has the required inames as the lead expression
-                lead_expr = csed.cse.child
+                # pick the first cse that has the required inames as the unification template
+                uni_template = csed.cse.child
                 break
 
-        if lead_expr is None:
-            raise RuntimeError("could not find a suitable 'lead' CSE that depends on "
+        if uni_template is None:
+            raise RuntimeError("could not find a suitable unification template that depends on "
                     "inames '%s'" % ",".join(independent_inames))
 
     # }}}
 
-    # FIXME: Do something with the footprint
-    # (CAUTION: Can be None if no independent_inames)
-    footprint = process_cses(kernel, lead_expr, independent_inames, cse_descriptors)
+    # {{{ make sure that independent inames and kernel inames do not overlap
+
+    # (and substitute in uni_template if any variable name changes are necessary)
+
+    if set(independent_inames) & kernel.all_inames():
+        old_to_new = {}
+
+        new_independent_inames = []
+        new_ind_iname_to_tag = {}
+        for i, iname in enumerate(independent_inames):
+            if iname in kernel.all_inames():
+                based_on = iname
+                if new_inames is not None and i < len(new_inames):
+                    based_on = new_inames[i]
+
+                new_iname = kernel.make_unique_var_name(
+                        based_on=iname, extra_used_vars=newly_created_var_names)
+                old_to_new[iname] = var(new_iname)
+                newly_created_var_names.add(new_iname)
+                new_independent_inames.append(new_iname)
+                new_ind_iname_to_tag[new_iname] = ind_iname_to_tag[iname]
+            else:
+                new_independent_inames.append(iname)
+                new_ind_iname_to_tag[iname] = ind_iname_to_tag[iname]
+
+        independent_inames = new_independent_inames
+        ind_iname_to_tag = new_ind_iname_to_tag
+        uni_template = (
+                SubstitutionMapper(make_subst_func(old_to_new))
+                (uni_template))
+
+    # }}}
+
+    # {{{ deal with iname deps of uni_template that are not independent_inames
+
+    # (We call these 'matching_vars', because they have to match exactly in
+    # every CSE. As above, they might need to be renamed to make them unique
+    # within the kernel.)
+
+    matching_vars = []
+    old_to_new = {}
+
+    for iname in (get_dependencies(uni_template)
+            - set(independent_inames)
+            - kernel.non_iname_variable_names()):
+        if iname in kernel.all_inames():
+            # need to rename to be unique
+            new_iname = kernel.make_unique_var_name(
+                    based_on=iname, extra_used_vars=newly_created_var_names)
+            old_to_new[iname] = var(new_iname)
+            newly_created_var_names.add(new_iname)
+            matching_vars.append(new_iname)
+        else:
+            matching_vars.append(iname)
+
+    if old_to_new:
+        uni_template = (
+                SubstitutionMapper(make_subst_func(old_to_new))
+                (uni_template))
+
+    # }}}
+
+    # {{{ align and intersect the footprint and the domain
+
+    # (If there are independent inames, this adds extra dimensions to the domain.)
+
+    footprint, matching_var_values = process_cses(kernel, uni_template,
+            independent_inames, matching_vars,
+            cse_descriptors)
+
+    if isinstance(footprint, isl.Set):
+        footprint = footprint.coalesce()
+        footprint_bsets = footprint.get_basic_sets()
+        if len(footprint_bsets) > 1:
+            raise NotImplementedError("CSE '%s' yielded a non-convex footprint"
+                    % cse_tag)
+
+        footprint, = footprint_bsets
+
+    ndim = kernel.space.dim(dim_type.set)
+    footprint = footprint.insert_dims(dim_type.set, 0, ndim)
+    for i in range(ndim):
+        footprint = footprint.set_dim_name(dim_type.set, i,
+                kernel.space.get_dim_name(dim_type.set, i))
+
+    from islpy import align_spaces
+    new_domain = align_spaces(kernel.domain, footprint).intersect(footprint)
+
+    # set matching vars equal to their unified value, eliminate them
+    from loopy.symbolic import aff_from_expr
+
+    assert set(matching_var_values) == set(matching_vars)
+
+    for var_name, value in matching_var_values.iteritems():
+        cns = isl.Constraint.equality_from_aff(
+                aff_from_expr(new_domain.get_space(), var(var_name) - value))
+        new_domain = new_domain.add_constraint(cns)
+
+    new_domain = (new_domain
+            .eliminate(dim_type.set,
+                new_domain.dim(dim_type.set)-len(matching_vars), len(matching_vars))
+            .remove_dims(dim_type.set,
+                new_domain.dim(dim_type.set)-len(matching_vars), len(matching_vars)))
+    new_domain = new_domain.remove_redundancies()
+
+    # }}}
+
+    # }}}
 
     # {{{ set up temp variable
 
@@ -449,7 +632,7 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
 
     target_var_base_indices, target_var_shape = \
             find_var_base_indices_and_shape_from_inames(
-                    new_domain, new_inames)
+                    new_domain, independent_inames)
 
     new_temporary_variables = kernel.temporary_variables.copy()
     new_temporary_variables[target_var_name] = TemporaryVariable(
@@ -461,9 +644,13 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
 
     # }}}
 
+    mv_subst = SubstitutionMapper(make_subst_func(
+        dict((mv, matching_var_values[mv]) for mv in matching_vars)))
+
     compute_insn = make_compute_insn(
-            kernel, cse_tag, lead_expr, target_var_name,
-            independent_inames, new_inames, ind_iname_to_tag,
+            kernel, cse_tag, mv_subst(uni_template),
+            target_var_name, target_var_base_indices,
+            independent_inames, ind_iname_to_tag,
             # pick one insn at random for dep check
             cse_descriptors[0].insn)
 
@@ -481,11 +668,12 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
             return CommonSubexpression(
                     rec(cse.child), cse.prefix)
 
-        lead_indices = csed.lead_index_exprs
+        indices = [csed.unif_var_dict[iname]-bi
+                for iname, bi in zip(independent_inames, target_var_base_indices)]
 
         new_outer_expr = var(target_var_name)
-        if lead_indices:
-            new_outer_expr = new_outer_expr[tuple(lead_indices)]
+        if indices:
+            new_outer_expr = new_outer_expr[tuple(indices)]
 
         return new_outer_expr
         # can't nest, don't recurse
@@ -501,8 +689,7 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
     # }}}
 
     new_iname_to_tag = kernel.iname_to_tag.copy()
-    for old_iname, new_iname in zip(independent_inames, new_inames):
-        new_iname_to_tag[new_iname] = ind_iname_to_tag[old_iname]
+    new_iname_to_tag.update(ind_iname_to_tag)
 
     return kernel.copy(
             domain=new_domain,
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 408fe45fa..8ae0482a7 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -661,6 +661,11 @@ class LoopKernel(Record):
         from islpy import dim_type
         return set(self.space.get_var_dict(dim_type.set).iterkeys())
 
+    @memoize_method
+    def non_iname_variable_names(self):
+        """Return the set of all argument and temporary variable names."""
+        return set(self.arg_dict) | set(self.temporary_variables)
+
     @memoize_method
     def all_insn_inames(self):
         from loopy.symbolic import get_dependencies
@@ -778,7 +783,7 @@ class LoopKernel(Record):
             var_name = insn.get_assignee_var_name()
 
             if var_name not in admissible_vars:
-                raise RuntimeError("writing to '%s' is not allowed" % var_name)
+                raise RuntimeError("variable '%s' not declared or not allowed for writing" % var_name)
             var_names = [var_name]
 
             for var_name in var_names:
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 49e9dfdbf..427c9be5c 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -17,8 +17,8 @@ from pymbolic.mapper.stringifier import \
         StringifyMapper as StringifyMapperBase
 from pymbolic.mapper.dependency import \
         DependencyMapper as DependencyMapperBase
-from pymbolic.mapper.unifier import BidirectionalUnifier \
-        as BidirectionalUnifierBase
+from pymbolic.mapper.unifier import UnidirectionalUnifier \
+        as UnidirectionalUnifierBase
 
 import numpy as np
 import islpy as isl
@@ -98,7 +98,7 @@ class DependencyMapper(DependencyMapperBase):
         return (self.rec(expr.expr)
                 - set(Variable(iname) for iname in expr.untagged_inames))
 
-class BidirectionalUnifier(BidirectionalUnifierBase):
+class UnidirectionalUnifier(UnidirectionalUnifierBase):
     def map_reduction(self, expr, other, unis):
         if not isinstance(other, type(expr)):
             return self.treat_mismatch(expr, other, unis)
@@ -480,7 +480,7 @@ def pw_aff_to_expr(pw_aff):
     (set, aff), = pieces
     return aff_to_expr(aff)
 
-def aff_from_expr(space, expr):
+def aff_from_expr(space, expr, vars_to_zero=set()):
     zero = isl.Aff.zero_on_domain(isl.LocalSpace.from_space(space))
     context = {}
     for name, (dt, pos) in space.get_var_dict().iteritems():
@@ -489,6 +489,9 @@ def aff_from_expr(space, expr):
 
         context[name] = zero.set_coefficient(dt, pos, 1)
 
+    for name in vars_to_zero:
+        context[name] = zero
+
     from pymbolic import evaluate
     return zero + evaluate(expr, context)
 
@@ -653,6 +656,18 @@ class ParametrizedSubstitutor(IdentityMapper):
 
 # }}}
 
+# {{{ wildcard -> unique variable mapper
+
+class WildcardToUniqueVariableMapper(IdentityMapper):
+    def __init__(self, unique_var_name_factory):
+        self.unique_var_name_factory = unique_var_name_factory
+
+    def map_wildcard(self, expr):
+        from pymbolic import var
+        return var(self.unique_var_name_factory())
+
+# }}}
+
 # {{{ prime-adder
 
 class PrimeAdder(IdentityMapper):
diff --git a/test/test_fem_assembly.py b/test/test_fem_assembly.py
index 72f310765..2c70c2693 100644
--- a/test/test_fem_assembly.py
+++ b/test/test_fem_assembly.py
@@ -56,6 +56,8 @@ def test_laplacian_stiffness(ctx_factory):
     knl = lp.split_dimension(knl, "K", 16, outer_tag="g.0", slabs=(0,1))
     knl = lp.split_dimension(knl, "K_inner", 4, inner_tag="ilp")
     knl = lp.tag_dimensions(knl, {"i": "l.0", "j": "l.1"})
+    knl = lp.add_prefetch(knl, 'jacInv', ["K_inner_outer", "K_inner_inner", "q"],
+            uni_template="jacInv[x,y,z,u]")
 
     kernel_gen = lp.generate_loop_schedules(knl,
             loop_priority=["K", "i", "j"])
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 40641e9c2..329384071 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -109,7 +109,10 @@ def test_axpy(ctx_factory):
                 lp.ArrayArg("z", dtype, shape="n,"),
                 lp.ScalarArg("n", np.int32, approximately=n),
                 ],
-            name="matmul")
+            name="axpy", assumptions="n>=1")
+
+    def variant_seq(knl):
+        return knl
 
     def variant_cpu(knl):
         unroll = 16
@@ -125,22 +128,24 @@ def test_axpy(ctx_factory):
         knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
         return knl
 
-    a = cl_random.rand(queue, n, dtype=dtype, luxury=2)
-    b = cl_random.rand(queue, n, dtype=dtype, luxury=2)
-    c = cl_array.zeros_like(a)
-    refsol = (2*a+3*b).get()
+    #x = cl_array.to_device(queue, np.random.rand(n).astype(dtype))
+    #y = cl_array.to_device(queue, np.random.rand(n).astype(dtype))
+    x = cl_random.rand(queue, n, dtype=dtype, luxury=2)
+    y = cl_random.rand(queue, n, dtype=dtype, luxury=2)
+    z = cl_array.zeros_like(x)
+    refsol = (2*x+3*y).get()
 
-    for variant in [variant_cpu, variant_gpu]:
+    for variant in [variant_seq, variant_cpu, variant_gpu]:
         kernel_gen = lp.generate_loop_schedules(variant(knl),
                 loop_priority=["i_inner_outer"])
         kernel_gen = lp.check_kernels(kernel_gen, dict(n=n))
 
         def launcher(kernel, gsize, lsize, check):
-            evt = kernel(queue, gsize(n), lsize(n), 2, a.data, 3, b.data, c.data, n,
+            evt = kernel(queue, gsize(n), lsize(n), 2, x.data, 3, y.data, z.data, n,
                     g_times_l=True)
 
             if check:
-                check_error(refsol, c.get())
+                check_error(refsol, z.get())
 
             return evt
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 2e01144a5..f74848a63 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -85,11 +85,11 @@ def test_multi_cse(ctx_factory):
 
 
 
-def test_stencil(ctx_factory):
+def test_bad_stencil(ctx_factory):
     ctx = ctx_factory()
 
     knl = lp.make_kernel(ctx.devices[0],
-            "{[i,j]: 0<= i,j <4}",
+            "{[i,j]: 0<= i,j < 32}",
             [
                 "[i] <float32> z[i,j] = -2*cse(a[i,j])"
                     " + cse(a[i,j-1])"
@@ -97,17 +97,59 @@ def test_stencil(ctx_factory):
                     " + cse(a[i-1,j])"
                     " + cse(a[i+1,i])" # watch out: i!
                 ],
-            [lp.ArrayArg("a", np.float32, shape=(4,4,))])
+            [
+                lp.ArrayArg("a", np.float32, shape=(32,32,))
+                ])
 
-    knl = lp.split_dimension(knl, "i", 16)
+    def variant_1(knl):
+        return knl
 
-    try:
-        lp.realize_cse(knl, None, np.float32, ["i_inner", "j"])
-    except RuntimeError, e:
-        assert "does not cover a subset" in str(e)
-        pass # expected!
-    else:
-        assert False # expecting an error
+    def variant_2(knl):
+        knl = lp.split_dimension(knl, "i", 16, outer_tag="g.1", inner_tag="l.1")
+        knl = lp.realize_cse(knl, None, np.float32, ["i_inner", "j"])
+        return knl
+
+    for variant in [variant_1, variant_2]:
+        kernel_gen = lp.generate_loop_schedules(variant(knl),
+                loop_priority=["i_outer", "i_inner_0", "j_0"])
+        kernel_gen = lp.check_kernels(kernel_gen)
+
+        for knl in kernel_gen:
+            print lp.generate_code(knl)
+
+
+
+
+
+def test_stencil(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i,j]: 0<= i,j < 32}",
+            [
+                "[i] <float32> z[i,j] = -2*cse(a[i,j])"
+                    " + cse(a[i,j-1])"
+                    " + cse(a[i,j+1])"
+                    " + cse(a[i-1,j])"
+                    " + cse(a[i+1,j])" # unlike test_bad_stencil: second index is j
+                ],
+            [
+                lp.ArrayArg("a", np.float32, shape=(32,32,))
+                ])
+
+    def variant_3(knl):
+        knl = lp.split_dimension(knl, "i", 16, outer_tag="g.1", inner_tag="l.1")
+        knl = lp.split_dimension(knl, "j", 16, outer_tag="g.0", inner_tag="l.0")
+        knl = lp.realize_cse(knl, None, np.float32, ["i_inner", "j_inner"])
+        return knl
+
+    for variant in [variant_3]:
+        kernel_gen = lp.generate_loop_schedules(variant(knl),
+                loop_priority=["i_outer", "i_inner_0", "j_0"])
+        kernel_gen = lp.check_kernels(kernel_gen)
+
+        for knl in kernel_gen:
+            print lp.generate_code(knl)
 
 
 
diff --git a/test/test_sem.py b/test/test_sem.py
index 7212860d0..d6641db23 100644
--- a/test/test_sem.py
+++ b/test/test_sem.py
@@ -307,7 +307,7 @@ def test_sem_3d(ctx_factory):
     def add_pf(knl):
         knl = lp.add_prefetch(knl, "G", ["gi", "m", "j", "k"], "G[gi,e,m,j,k]")
         knl = lp.add_prefetch(knl, "D", ["m", "j"])
-        knl = lp.add_prefetch(knl, "u", ["i", "j", "k"], "u[e,i,j,k]")
+        knl = lp.add_prefetch(knl, "u", ["i", "j", "k"], "u[*,i,j,k]")
         knl = lp.realize_cse(knl, "ur", np.float32, ["k", "j", "m"])
         knl = lp.realize_cse(knl, "us", np.float32, ["i", "m", "k"])
         knl = lp.realize_cse(knl, "ut", np.float32, ["i", "j", "m"])
-- 
GitLab