From 724bf1c3d4189c7cb2a9d64aacef75bf8bb93454 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 31 Oct 2011 00:35:13 -0400
Subject: [PATCH] Allow multiple references to a CSE with different indices in
 each.

---
 MEMO                 |  10 +-
 loopy/__init__.py    | 283 ++-------------------
 loopy/cse.py         | 584 +++++++++++++++++++++++++++++++++++++++++++
 loopy/isl_helpers.py |   2 +-
 loopy/kernel.py      |   7 +-
 loopy/symbolic.py    |  29 ++-
 test/test_loopy.py   |  59 ++++-
 7 files changed, 700 insertions(+), 274 deletions(-)
 create mode 100644 loopy/cse.py

diff --git a/MEMO b/MEMO
index 25d5585fa..3914782d7 100644
--- a/MEMO
+++ b/MEMO
@@ -45,10 +45,16 @@ To-do
 - variable shuffle detection
   -> will need unification
 
-- Fix all tests
-
 - Automatically generate testing code vs. sequential.
 
+- For forced workgroup sizes: check that at least one iname
+  maps to it.
+
+- If isl can prove that all operands are positive, may use '/' instead of
+  'floor_div'.
+
+- Fix all tests
+
 - Deal with equality constraints.
   (These arise, e.g., when partitioning a loop of length 16 into 16s.)
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index dd8d2147c..4eb0e2c55 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -23,11 +23,19 @@ class LoopyAdvisory(UserWarning):
 from loopy.kernel import ScalarArg, ArrayArg, ImageArg
 
 from loopy.kernel import AutoFitLocalIndexTag
+from loopy.cse import realize_cse
 from loopy.preprocess import preprocess_kernel
 from loopy.schedule import generate_loop_schedules
 from loopy.compiled import CompiledKernel, drive_timing_run
 from loopy.check import check_kernels
 
+__all__ = ["ScalarArg", "ArrayArg", "ImageArg",
+        "preprocess_kernel", "generate_loop_schedules",
+        "CompiledKernel", "drive_timing_run", "check_kernels",
+        "make_kernel", "split_dimension", "join_dimensions",
+        "tag_dimensions", "realize_cse", "add_prefetch"
+        ]
+
 # }}}
 
 # {{{ kernel creation
@@ -222,13 +230,16 @@ def make_kernel(*args, **kwargs):
 
 # }}}
 
-# {{{ user-facing kernel manipulation functionality
+# {{{ dimension split
 
 def split_dimension(kernel, iname, inner_length,
         outer_iname=None, inner_iname=None,
         outer_tag=None, inner_tag=None,
         slabs=(0, 0)):
 
+    if kernel.iname_to_tag.get(iname) is not None:
+        raise RuntimeError("cannot split already tagged iname '%s'" % iname)
+
     if iname not in kernel.all_inames():
         raise ValueError("cannot split loop for unknown variable '%s'" % iname)
 
@@ -242,8 +253,8 @@ def split_dimension(kernel, iname, inner_length,
 
     def process_set(s):
         s = s.add_dims(dim_type.set, 2)
-        s.set_dim_name(dim_type.set, outer_var_nr, outer_iname)
-        s.set_dim_name(dim_type.set, inner_var_nr, inner_iname)
+        s = s.set_dim_name(dim_type.set, outer_var_nr, outer_iname)
+        s = s.set_dim_name(dim_type.set, inner_var_nr, inner_iname)
 
         from loopy.isl_helpers import make_slab
 
@@ -307,8 +318,9 @@ def split_dimension(kernel, iname, inner_length,
 
     return tag_dimensions(result, {outer_iname: outer_tag, inner_iname: inner_tag})
 
+# }}}
 
-
+# {{{ dimension join
 
 def join_dimensions(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()):
     """
@@ -396,8 +408,9 @@ def join_dimensions(kernel, inames, new_iname=None, tag=AutoFitLocalIndexTag()):
 
     return tag_dimensions(result, {new_iname: tag})
 
+# }}}
 
-
+# {{{ dimension tag
 
 def tag_dimensions(kernel, iname_to_tag):
     from loopy.kernel import parse_tag
@@ -432,265 +445,9 @@ def tag_dimensions(kernel, iname_to_tag):
 
     return kernel.copy(iname_to_tag=new_iname_to_tag)
 
+# }}}
 
-
-
-
-def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=None,
-        dup_iname_to_tag={}, new_inames=None, default_tag_class=AutoFitLocalIndexTag):
-    """
-    :arg duplicate_inames: which inames are supposed to be separate loops
-        in the CSE. Also determines index order of temporary array.
-    :arg parallel_inames: only a convenient interface for dup_iname_to_tag
-    """
-
-    dtype = np.dtype(dtype)
-
-    from pytools import any
-
-    # {{{ process parallel_inames and dup_iname_to_tag arguments
-
-    if parallel_inames is None:
-        # default to all-parallel
-        parallel_inames = duplicate_inames
-
-    dup_iname_to_tag = dup_iname_to_tag.copy()
-    for piname in parallel_inames:
-        dup_iname_to_tag[piname] = default_tag_class()
-
-    for diname in duplicate_inames:
-        dup_iname_to_tag.setdefault(diname, None)
-
-    if not set(dup_iname_to_tag.iterkeys()) <= set(duplicate_inames):
-        raise RuntimeError("paralleization/tag info for non-duplicated inames "
-                "may not be passed")
-
-    # here, all information is consolidated into dup_iname_to_tag
-
-    # }}}
-
-    # {{{ process new_inames argument, think of new inames for inames to be duplicated
-
-    if new_inames is None:
-        new_inames = [None] * len(duplicate_inames)
-
-    if len(new_inames) != len(duplicate_inames):
-        raise ValueError("If given, the new_inames argument must have the "
-                "same length as duplicate_inames")
-
-    temp_new_inames = []
-    for old_iname, new_iname in zip(duplicate_inames, new_inames):
-        if new_iname is None:
-            new_iname = kernel.make_unique_var_name(old_iname)
-        temp_new_inames.append(new_iname)
-
-    new_inames = temp_new_inames
-
-    old_to_new_iname = dict(zip(duplicate_inames, new_inames))
-
-    # }}}
-
-    target_var_name = kernel.make_unique_var_name(cse_tag)
-
-    from loopy.kernel import (LocalIndexTagBase, GroupIndexTag, IlpTag)
-    target_var_is_local = any(
-            isinstance(tag, LocalIndexTagBase)
-            for tag in dup_iname_to_tag.itervalues())
-
-    cse_lookup_table = {}
-
-    cse_result_insns = []
-
-    def map_cse(expr, rec):
-        if expr.prefix != cse_tag:
-            return
-
-        # FIXME stencils and variable shuffle detection would happen here
-
-        try:
-            cse_replacement, dep_id = cse_lookup_table[expr]
-        except KeyError:
-            pass
-        else:
-            return cse_replacement
-
-        if cse_result_insns:
-            raise RuntimeError("CSE tag '%s' is not unique" % cse_tag)
-
-        # {{{ decide what to do with each iname
-
-        forced_iname_deps = set()
-
-        from loopy.symbolic import IndexVariableFinder
-        dependencies = IndexVariableFinder(
-                include_reduction_inames=False)(expr.child)
-
-        parent_inames = insn.all_inames() | insn.reduction_inames()
-        assert dependencies <= parent_inames
-
-        for iname in parent_inames:
-            if iname in duplicate_inames:
-                tag = dup_iname_to_tag[iname]
-            else:
-                tag = kernel.iname_to_tag.get(iname)
-
-            if isinstance(tag, LocalIndexTagBase):
-                kind = "l"
-            elif isinstance(tag, GroupIndexTag):
-                kind = "g"
-            elif isinstance(tag, IlpTag):
-                kind = "i"
-            else:
-                kind = "o"
-
-            if iname not in duplicate_inames and iname in dependencies:
-                if (
-                        (target_var_is_local and kind in "li")
-                        or
-                        (not target_var_is_local and kind in "i")):
-                    raise RuntimeError(
-                            "When realizing CSE with tag '%s', encountered iname "
-                            "'%s' which is depended upon by the CSE and tagged "
-                            "'%s', but not duplicated. The CSE would "
-                            "inherit this iname, which would lead to a write race. "
-                            "A likely solution of this problem is to also duplicate this "
-                            "iname."
-                            % (expr.prefix, iname, tag))
-
-            if iname in duplicate_inames and kind == "g":
-                raise RuntimeError("duplicating the iname '%s' into "
-                        "group index axes is not helpful, as they cannot "
-                        "collaborate in computing a local variable"
-                        %iname)
-
-            if iname in dependencies:
-                if not target_var_is_local and iname in duplicate_inames and kind == "l":
-                    raise RuntimeError("invalid: hardware-parallelized "
-                            "fetch into private variable")
-
-                # otherwise: all happy
-                continue
-
-            # the iname is *not* a dependency of the fetch expression
-            if iname in duplicate_inames:
-                raise RuntimeError("duplicating an iname ('%s') "
-                        "that the CSE ('%s') does not depend on "
-                        "does not make sense" % (iname, expr.child))
-
-            # Which iname dependencies are carried over from CSE host
-            # to the CSE compute instruction?
-
-            if not target_var_is_local:
-                # If we're writing to a private variable, then each
-                # hardware-parallel iname must execute its own copy of
-                # the CSE compute instruction. After all, each work item
-                # has its own set of private variables.
-
-                force_dependency = kind in "gl"
-            else:
-                # If we're writing to a local variable, then all other local
-                # dimensions see our updates, and thus they do *not* need to
-                # execute their own copy of this instruction.
-
-                force_dependency = kind == "g"
-
-            if force_dependency:
-                forced_iname_deps.add(iname)
-
-        # }}}
-
-        # {{{ concoct new inner and outer expressions
-
-        from pymbolic import var
-        assignee = var(target_var_name)
-        new_outer_expr = assignee
-
-        if duplicate_inames:
-            assignee = assignee[tuple(
-                var(iname) for iname in new_inames
-                )]
-            new_outer_expr = new_outer_expr[tuple(
-                var(iname) for iname in duplicate_inames
-                )]
-
-        from loopy.symbolic import SubstitutionMapper
-        from pymbolic.mapper.substitutor import make_subst_func
-        subst_map = SubstitutionMapper(make_subst_func(
-            dict(
-                (old_iname, var(new_iname))
-                for old_iname, new_iname in zip(duplicate_inames, new_inames))))
-        new_inner_expr = subst_map(rec(expr.child))
-
-        # }}}
-
-        from loopy.kernel import Instruction
-        new_insn = Instruction(
-                id=kernel.make_unique_instruction_id(based_on=cse_tag),
-                assignee=assignee,
-                expression=new_inner_expr,
-                forced_iname_deps=forced_iname_deps)
-
-        cse_result_insns.append(new_insn)
-
-        return new_outer_expr
-
-    from loopy.symbolic import CSECallbackMapper
-    cse_cb_mapper = CSECallbackMapper(map_cse)
-
-    new_insns = []
-    for insn in kernel.instructions:
-        was_empty = not bool(cse_result_insns)
-        new_expr = cse_cb_mapper(insn.expression)
-
-        if was_empty and cse_result_insns:
-            new_insns.append(insn.copy(expression=new_expr))
-        else:
-            new_insns.append(insn)
-
-    new_insns.extend(cse_result_insns)
-
-    # {{{ build new domain, duplicating each constraint on duplicated inames
-
-    from loopy.isl_helpers import duplicate_axes
-    new_domain = duplicate_axes(kernel.domain, duplicate_inames, new_inames)
-
-    # }}}
-
-    # {{{ set up data for temp variable
-
-
-    from loopy.kernel import (TemporaryVariable,
-            find_var_base_indices_and_shape_from_inames)
-
-    target_var_base_indices, target_var_shape = \
-            find_var_base_indices_and_shape_from_inames(
-                    new_domain, new_inames)
-
-    new_temporary_variables = kernel.temporary_variables.copy()
-    new_temporary_variables[target_var_name] = TemporaryVariable(
-            name=target_var_name,
-            dtype=dtype,
-            base_indices=target_var_base_indices,
-            shape=target_var_shape,
-            is_local=target_var_is_local)
-
-    # }}}
-
-    new_iname_to_tag = kernel.iname_to_tag.copy()
-    for old_iname, new_iname in zip(duplicate_inames, new_inames):
-        new_iname_to_tag[new_iname] = dup_iname_to_tag[old_iname]
-
-    return kernel.copy(
-            domain=new_domain,
-            instructions=new_insns,
-            temporary_variables=new_temporary_variables,
-            iname_to_tag=new_iname_to_tag)
-
-
-
-
-
-# {{{ convenience
+# {{{ convenience: add_prefetch
 
 def add_prefetch(kernel, var_name, fetch_dims=[], new_inames=None):
     used_cse_tags = set()
diff --git a/loopy/cse.py b/loopy/cse.py
new file mode 100644
index 000000000..d169d0d5b
--- /dev/null
+++ b/loopy/cse.py
@@ -0,0 +1,584 @@
+from __future__ import division
+
+import islpy as isl
+from islpy import dim_type
+from loopy.kernel import AutoFitLocalIndexTag
+import numpy as np
+
+from pytools import Record
+from pymbolic import var
+
+
+
+
+def should_cse_force_iname_dep(iname, duplicate_inames, tag, dependencies,
+        target_var_is_local, cse):
+    from loopy.kernel import (LocalIndexTagBase, GroupIndexTag, IlpTag)
+
+    if isinstance(tag, LocalIndexTagBase):
+        kind = "l"
+    elif isinstance(tag, GroupIndexTag):
+        kind = "g"
+    elif isinstance(tag, IlpTag):
+        kind = "i"
+    else:
+        kind = "o"
+
+    if iname not in duplicate_inames and iname in dependencies:
+        if (
+                (target_var_is_local and kind in "li")
+                or
+                (not target_var_is_local and kind in "i")):
+            raise RuntimeError(
+                    "When realizing CSE with tag '%s', encountered iname "
+                    "'%s' which is depended upon by the CSE and tagged "
+                    "'%s', but not duplicated. The CSE would "
+                    "inherit this iname, which would lead to a write race. "
+                    "A likely solution of this problem is to also duplicate this "
+                    "iname."
+                    % (cse.prefix, iname, tag))
+
+    if iname in duplicate_inames and kind == "g":
+        raise RuntimeError("duplicating the iname '%s' into "
+                "group index axes is not helpful, as they cannot "
+                "collaborate in computing a local variable"
+                %iname)
+
+    if iname in dependencies:
+        if not target_var_is_local and iname in duplicate_inames and kind == "l":
+            raise RuntimeError("invalid: hardware-parallelized "
+                    "fetch into private variable")
+
+        return False
+
+    # the iname is *not* a dependency of the fetch expression
+    if iname in duplicate_inames:
+        raise RuntimeError("duplicating an iname ('%s') "
+                "that the CSE ('%s') does not depend on "
+                "does not make sense" % (iname, cse.child))
+
+    # Which iname dependencies are carried over from CSE host
+    # to the CSE compute instruction?
+
+    if not target_var_is_local:
+        # If we're writing to a private variable, then each
+        # hardware-parallel iname must execute its own copy of
+        # the CSE compute instruction. After all, each work item
+        # has its own set of private variables.
+
+        return kind in "gl"
+    else:
+        # If we're writing to a local variable, then all other local
+        # dimensions see our updates, and thus they do *not* need to
+        # execute their own copy of this instruction.
+
+        return kind == "g"
+
+
+
+
+class CSEDescriptor(Record):
+    __slots__ = ["insn", "cse", "independent_inames",
+            "lead_index_exprs"]
+
+
+
+
+def to_parameters_or_project_out(param_inames, set_inames, set):
+    for iname in set.get_space().get_var_dict().keys():
+        if iname in param_inames:
+            dt, idx = set.get_space().get_var_dict()[iname]
+            set = set.move_dims(
+                    dim_type.param, set.dim(dim_type.param),
+                    dt, idx, 1)
+        elif iname in set_inames:
+            pass
+        else:
+            dt, idx = set.get_space().get_var_dict()[iname]
+            set = set.project_out(dt, idx, 1)
+
+    return set
+
+
+
+
+def solve_affine_equations_for_lhs(targets, equations, parameters):
+    # Not a very good solver: The desired variable must already
+    # occur with a coefficient of 1 on the lhs, and with no other
+    # targets on that lhs.
+
+    from loopy.symbolic import CoefficientCollector
+    coeff_coll = CoefficientCollector()
+
+    target_values = {}
+
+    for lhs, rhs in equations:
+        lhs_coeffs = coeff_coll(lhs)
+        rhs_coeffs = coeff_coll(rhs)
+
+        def shift_to_rhs(key):
+            rhs_coeffs[key] = rhs_coeffs.get(key, 0) - lhs_coeffs[key]
+            del lhs_coeffs[key]
+
+        for key in list(lhs_coeffs.iterkeys()):
+            if key in targets:
+                continue
+            elif key in parameters or key == 1:
+                shift_to_rhs(key)
+            else:
+                raise RuntimeError("unexpected key")
+
+        if len(lhs_coeffs) > 1:
+            raise RuntimeError("comically unable to solve '%s = %s' "
+                    "for one of the target variables '%s'"
+                    % (lhs, rhs, ",".join(targets)))
+
+        (tgt_name, coeff), = lhs_coeffs.iteritems()
+        if coeff != 1:
+            raise RuntimeError("comically unable to solve '%s = %s' "
+                    "for one of the target variables '%s'"
+                    % (lhs, rhs, ",".join(targets)))
+
+        solution = 0
+        for key, coeff in rhs_coeffs.iteritems():
+            if key == 1:
+                solution += coeff
+            else:
+                solution += coeff*var(key)
+
+        assert tgt_name not in target_values
+        target_values[tgt_name] = solution
+
+    return [target_values[tname] for tname in targets]
+
+
+
+
+def process_cses(kernel, lead_csed, cse_descriptors):
+    from pymbolic.mapper.unifier import BidirectionalUnifier
+
+    # {{{ parameter set/dependency finding
+
+    from loopy.symbolic import DependencyMapper
+    internal_dep_mapper = DependencyMapper(composite_leaves=False)
+
+    def get_deps(expr):
+        return set(dep.name for dep in internal_dep_mapper(expr))
+
+    # Everything that is not one of the duplicate/independent inames
+    # is turned into a parameter.
+
+    lead_csed.independent_inames = set(lead_csed.independent_inames)
+    lead_deps = get_deps(lead_csed.cse.child) & kernel.all_inames()
+    params = lead_deps - set(lead_csed.independent_inames)
+
+    # }}}
+
+    lead_domain = to_parameters_or_project_out(params,
+            lead_csed.independent_inames, kernel.domain)
+    lead_space = lead_domain.get_space()
+
+    footprint = lead_domain
+
+    uni_recs = []
+    for csed in cse_descriptors:
+        # {{{ find dependencies
+
+        cse_deps = get_deps(csed.cse.child) & kernel.all_inames()
+        csed.independent_inames = cse_deps - params
+
+        # }}}
+
+        # {{{ find unifier
+
+        unif = BidirectionalUnifier(
+                lhs_mapping_candidates=lead_csed.independent_inames,
+                rhs_mapping_candidates=csed.independent_inames)
+        unifiers = unif(lead_csed.cse.child, csed.cse.child)
+        if not unifiers:
+            raise RuntimeError("Unable to unify  "
+            "CSEs '%s' and '%s'" % (lead_csed.cse.child, csed.cse.child))
+
+        # }}}
+
+        found_good_unifier = False
+
+        for unifier in unifiers:
+            # {{{ construct, check mapping
+
+            rhs_domain = to_parameters_or_project_out(
+                    params, csed.independent_inames, kernel.domain)
+            rhs_space = rhs_domain.get_space()
+
+            map_space = lead_space
+            ln = lead_space.dim(dim_type.set)
+            map_space = map_space.move_dims(dim_type.in_, 0, dim_type.set, 0, ln)
+            rn = rhs_space.dim(dim_type.set)
+            map_space = map_space.add_dims(dim_type.out, rn)
+            for i, iname in enumerate(csed.independent_inames):
+                map_space = map_space.set_dim_name(dim_type.out, i, iname+"'")
+
+            set_space = map_space.move_dims(
+                    dim_type.out, rn,
+                    dim_type.in_, 0, ln).range()
+
+            var_map = None
+
+            from loopy.symbolic import aff_from_expr, PrimeAdder
+            add_primes = PrimeAdder(csed.independent_inames)
+            for lhs, rhs in unifier.equations:
+                cns = isl.Constraint.equality_from_aff(
+                        aff_from_expr(set_space, lhs - add_primes(rhs)))
+
+                cns_map = isl.BasicMap.from_constraint(cns)
+                if var_map is None:
+                    var_map = cns_map
+                else:
+                    var_map = var_map.intersect(cns_map)
+
+            var_map = var_map.move_dims(
+                    dim_type.in_, 0,
+                    dim_type.out, rn, ln)
+
+            restr_rhs_map = (
+                    isl.Map.from_basic_map(var_map)
+                    .intersect_range(rhs_domain))
+
+            # Sanity check: If the range of the map does not recover the
+            # domain of the expression, the unifier must have been no
+            # good.
+            if restr_rhs_map.range() != rhs_domain:
+                continue
+
+            # Sanity check: Injectivity here means that unique lead indices
+            # can be found for each 
+
+            if not var_map.is_injective():
+                raise RuntimeError("In CSEs '%s' and '%s': "
+                        "cannot find lead indices uniquely"
+                        % (lead_csed.cse.child, csed.cse.child))
+
+            lead_index_set = restr_rhs_map.domain()
+
+            footprint = footprint.union(lead_index_set)
+
+            # FIXME: This restriction will be lifted in the future, and the
+            # footprint will instead be used as the lead domain.
+
+            if not lead_index_set.is_subset(lead_domain):
+                raise RuntimeError("Index range of CSE '%s' does not cover a "
+                        "subset of lead CSE '%s'"
+                        % (csed.cse.child, lead_csed.cse.child))
+
+            found_good_unifier = True
+
+            # }}}
+
+        if not found_good_unifier:
+            raise RuntimeError("No valid unifier for '%s' and '%s'"
+                    % (csed.cse.child, lead_csed.cse.child))
+
+        uni_recs.append(unifier)
+
+        # {{{ solve for lead indices
+
+        csed.lead_index_exprs = solve_affine_equations_for_lhs(
+                lead_csed.independent_inames,
+                unifier.equations, params)
+
+        # }}}
+
+    return footprint.coalesce()
+
+
+
+
+
+def make_compute_insn(kernel, lead_csed, target_var_name, target_var_is_local,
+        new_inames, ind_iname_to_tag):
+    insn = lead_csed.insn
+
+    # {{{ decide whether to force a dep
+
+    forced_iname_deps = set()
+
+    from loopy.symbolic import IndexVariableFinder
+    dependencies = IndexVariableFinder(
+            include_reduction_inames=False)(lead_csed.cse.child)
+
+    parent_inames = insn.all_inames() | insn.reduction_inames()
+    assert dependencies <= parent_inames
+
+    for iname in parent_inames:
+        if iname in lead_csed.independent_inames:
+            tag = ind_iname_to_tag[iname]
+        else:
+            tag = kernel.iname_to_tag.get(iname)
+
+        if should_cse_force_iname_dep(
+                iname, lead_csed.independent_inames, tag, dependencies,
+                target_var_is_local, lead_csed.cse):
+            forced_iname_deps.add(iname)
+
+    # }}}
+
+    assignee = var(target_var_name)
+
+    if lead_csed.independent_inames:
+        assignee = assignee[tuple(
+            var(iname) for iname in new_inames
+            )]
+
+    from loopy.symbolic import SubstitutionMapper
+    from pymbolic.mapper.substitutor import make_subst_func
+    subst_map = SubstitutionMapper(make_subst_func(
+        dict(
+            (old_iname, var(new_iname))
+            for old_iname, new_iname in zip(lead_csed.independent_inames,
+                new_inames))))
+    new_inner_expr = subst_map(lead_csed.cse.child)
+
+    insn_prefix = lead_csed.cse.prefix
+    if insn_prefix is None:
+        insn_prefix = "cse"
+    from loopy.kernel import Instruction
+    return Instruction(
+            id=kernel.make_unique_instruction_id(based_on=insn_prefix),
+            assignee=assignee,
+            expression=new_inner_expr,
+            forced_iname_deps=forced_iname_deps)
+
+
+
+
+def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
+        ind_iname_to_tag={}, new_inames=None, default_tag_class=AutoFitLocalIndexTag,
+        follow_tag=None):
+    """
+    :arg independent_inames: which inames are supposed to be separate loops
+        in the CSE. Also determines index order of temporary array.
+    """
+
+    if not set(independent_inames) <= kernel.all_inames():
+        raise ValueError("cannot make iname '%s' independent--"
+                "they don't already exist" % ",".join(
+                    set(independent_inames)-kernel.all_inames()))
+
+    # {{{ process parallel_inames and ind_iname_to_tag arguments
+
+    ind_iname_to_tag = ind_iname_to_tag.copy()
+
+    for iname in independent_inames:
+        ind_iname_to_tag.setdefault(iname, default_tag_class())
+
+    if not set(ind_iname_to_tag.iterkeys()) <= set(independent_inames):
+        raise RuntimeError("tags for non-new inames may not be passed")
+
+    # here, all information is consolidated into ind_iname_to_tag
+
+    # }}}
+
+    # {{{ process new_inames argument, think of new inames for inames to be duplicated
+
+    if new_inames is None:
+        new_inames = [None] * len(independent_inames)
+
+    if len(new_inames) != len(independent_inames):
+        raise ValueError("If given, the new_inames argument must have the "
+                "same length as independent_inames")
+
+    temp_new_inames = []
+    for old_iname, new_iname in zip(independent_inames, new_inames):
+        if new_iname is None:
+            new_iname = kernel.make_unique_var_name(old_iname, set(temp_new_inames))
+            assert new_iname != old_iname
+
+        temp_new_inames.append(new_iname)
+
+    new_inames = temp_new_inames
+
+    # }}}
+
+    from loopy.isl_helpers import duplicate_axes
+    new_domain = duplicate_axes(kernel.domain, independent_inames, new_inames)
+
+    # {{{ gather cse descriptors
+
+    eligible_tags = [cse_tag]
+    if follow_tag is not None:
+        eligible_tags.append(follow_tag)
+
+    cse_descriptors = []
+
+    def gather_cses(cse, rec):
+        if cse.prefix not in eligible_tags:
+            rec(cse.child)
+            return
+
+        cse_descriptors.append(
+                CSEDescriptor(insn=insn, cse=cse))
+        # can't nest, don't recurse
+
+    from loopy.symbolic import CSECallbackMapper
+    cse_cb_mapper = CSECallbackMapper(gather_cses)
+
+    for insn in kernel.instructions:
+        cse_cb_mapper(insn.expression)
+
+    # }}}
+
+    # {{{ find/pick the lead cse
+
+    if not cse_descriptors:
+        raise RuntimeError("no CSEs tagged '%s' found" % cse_tag)
+
+    lead_cse_indices = [i for i, csed in enumerate(cse_descriptors) 
+            if csed.cse.prefix == cse_tag]
+    if follow_tag is not None:
+        if len(lead_cse_indices) != 1:
+            raise RuntimeError("%d lead CSEs (should be exactly 1) found for tag '%s'"
+                    % (len(lead_cse_indices), cse_tag))
+
+        lead_idx, = lead_cse_indices
+    else:
+        # pick a lead CSE at random
+        lead_idx = 0
+
+    lead_csed = cse_descriptors.pop(lead_idx)
+    lead_csed.independent_inames = independent_inames
+
+    # }}}
+
+    # FIXME: Do something with the footprint
+    footprint = process_cses(kernel, lead_csed, cse_descriptors)
+
+    # {{{ set up temp variable
+
+    var_base = cse_tag
+    if var_base is None:
+        var_base = "cse"
+    target_var_name = kernel.make_unique_var_name(var_base)
+
+    from loopy.kernel import LocalIndexTagBase
+    target_var_is_local = any(
+            isinstance(tag, LocalIndexTagBase)
+            for tag in ind_iname_to_tag.itervalues())
+
+    from loopy.kernel import (TemporaryVariable,
+            find_var_base_indices_and_shape_from_inames)
+
+    target_var_base_indices, target_var_shape = \
+            find_var_base_indices_and_shape_from_inames(
+                    new_domain, new_inames)
+
+    new_temporary_variables = kernel.temporary_variables.copy()
+    new_temporary_variables[target_var_name] = TemporaryVariable(
+            name=target_var_name,
+            dtype=np.dtype(dtype),
+            base_indices=target_var_base_indices,
+            shape=target_var_shape,
+            is_local=target_var_is_local)
+
+    # }}}
+
+    compute_insn = make_compute_insn(
+            kernel, lead_csed, target_var_name, target_var_is_local,
+            new_inames, ind_iname_to_tag)
+
+    # {{{ substitute variable references into instructions
+
+    def subst_cses(cse, rec):
+        if cse is lead_csed.cse:
+            csed = lead_csed
+
+            lead_indices = [var(iname) for iname in independent_inames]
+        else:
+            for csed in cse_descriptors:
+                if cse is csed.cse:
+                    break
+
+            if cse is not csed.cse:
+                return rec(cse.child)
+
+            lead_indices = csed.lead_index_exprs
+
+        new_outer_expr = var(target_var_name)
+        if lead_indices:
+            new_outer_expr = new_outer_expr[tuple(lead_indices)]
+
+        return new_outer_expr
+        # can't nest, don't recurse
+
+    cse_cb_mapper = CSECallbackMapper(subst_cses)
+
+    new_insns = [compute_insn] + [
+            insn.copy(expression=cse_cb_mapper(insn.expression))
+            for insn in kernel.instructions]
+
+    # }}}
+
+    new_iname_to_tag = kernel.iname_to_tag.copy()
+    for old_iname, new_iname in zip(independent_inames, new_inames):
+        new_iname_to_tag[new_iname] = ind_iname_to_tag[old_iname]
+
+    return kernel.copy(
+            domain=new_domain,
+            instructions=new_insns,
+            temporary_variables=new_temporary_variables,
+            iname_to_tag=new_iname_to_tag)
+
+
+
+
+def realize_cse_old(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=None,
+        dup_iname_to_tag={}, new_inames=None, default_tag_class=AutoFitLocalIndexTag):
+    """
+    :arg duplicate_inames: which inames are supposed to be separate loops
+        in the CSE. Also determines index order of temporary array.
+    :arg parallel_inames: only a convenient interface for dup_iname_to_tag
+    """
+
+    dtype = np.dtype(dtype)
+
+
+    cse_lookup_table = []
+    cse_result_insns = []
+
+    def map_cse(cse, rec):
+
+        # {{{ concoct new inner and outer expressions
+
+        # }}}
+
+        cse_result_insns.append(new_insn)
+        cse_lookup_table.append((cse.child, new_outer_expr))
+
+        return new_outer_expr
+
+    from loopy.symbolic import CSECallbackMapper
+    cse_cb_mapper = CSECallbackMapper(map_cse)
+
+    new_insns = []
+    for insn in kernel.instructions:
+        was_empty = not bool(cse_result_insns)
+        new_expr = cse_cb_mapper(insn.expression)
+
+        if was_empty and cse_result_insns:
+            new_insns.append(insn.copy(expression=new_expr))
+        else:
+            new_insns.append(insn)
+
+    new_insns.extend(cse_result_insns)
+
+    # build new domain, duplicating each constraint on duplicated inames
+
+
+    new_iname_to_tag = kernel.iname_to_tag.copy()
+    for old_iname, new_iname in zip(duplicate_inames, new_inames):
+        new_iname_to_tag[new_iname] = dup_iname_to_tag[old_iname]
+
+
+
+
+
+# vim: foldmethod=marker
diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index a7594e393..ba79878e9 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -202,7 +202,7 @@ def duplicate_axes(isl_obj, duplicate_inames, new_inames):
 
     iname_to_dim = more_dims.get_space().get_var_dict()
 
-    moved_dims = isl_obj
+    moved_dims = isl_obj.copy()
 
     for old_iname, new_iname in zip(duplicate_inames, new_inames):
         old_dt, old_idx = iname_to_dim[old_iname]
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 25d1ca77d..ad27038e6 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -773,7 +773,10 @@ class LoopKernel(Record):
             sorted_axes = sorted(size_dict.iterkeys())
 
             while sorted_axes or forced_sizes:
-                cur_axis = sorted_axes.pop(0)
+                if sorted_axes:
+                    cur_axis = sorted_axes.pop(0)
+                else:
+                    cur_axis = None
 
                 if len(size_list) in forced_sizes:
                     size_list.append(
@@ -782,6 +785,8 @@ class LoopKernel(Record):
                                 + forced_sizes.pop(len(size_list))))
                     continue
 
+                assert cur_axis is not None
+
                 while cur_axis > len(size_list):
                     from loopy import LoopyAdvisory
                     from warnings import warn
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index e2f48eed9..f41e1b5e3 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -93,11 +93,16 @@ class FunctionToPrimitiveMapper(IdentityMapper):
         from pymbolic.primitives import Variable
         if isinstance(expr.function, Variable) and expr.function.name == "cse":
             from pymbolic.primitives import CommonSubexpression
-            if len(expr.parameters) == 2:
-                if not isinstance(expr.parameters[1], Variable):
-                    raise TypeError("second argument to cse() must be a symbol")
+            if len(expr.parameters) in [1, 2]:
+                if len(expr.parameters) == 2:
+                    if not isinstance(expr.parameters[1], Variable):
+                        raise TypeError("second argument to cse() must be a symbol")
+                    tag = expr.parameters[1].name
+                else:
+                    tag = None
+
                 return CommonSubexpression(
-                        self.rec(expr.parameters[0]), expr.parameters[1].name)
+                        self.rec(expr.parameters[0]), tag)
             else:
                 raise TypeError("cse takes two arguments")
 
@@ -376,7 +381,7 @@ class LoopyCCodeMapper(CCodeMapper):
 
 # }}}
 
-# {{{ aff -> expr conversion
+# {{{ aff <-> expr conversion
 
 def aff_to_expr(aff, except_name=None, error_on_name=None):
     if except_name is not None and error_on_name is not None:
@@ -551,6 +556,20 @@ class VariableFetchCSEMapper(IdentityMapper):
 
 # }}}
 
+# {{{ prime-adder
+
+class PrimeAdder(IdentityMapper):
+    def __init__(self, which_vars):
+        self.which_vars = which_vars
+
+    def map_variable(self, expr):
+        from pymbolic import var
+        if expr.name in self.which_vars:
+            return var(expr.name+"'")
+        else:
+            return expr
+
+# }}}
 
 
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index f5eef7318..2e01144a5 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2,15 +2,19 @@ from __future__ import division
 
 import numpy as np
 import loopy as lp
+import pyopencl as cl
 
 from pyopencl.tools import pytest_generate_tests_for_pyopencl \
         as pytest_generate_tests
 
+__all__ = ["pytest_generate_tests",
+    "cl" # 'cl.create_some_context'
+    ]
+
 
 
 
 def test_owed_barriers(ctx_factory):
-    dtype = np.float32
     ctx = ctx_factory()
 
     knl = lp.make_kernel(ctx.devices[0],
@@ -47,7 +51,7 @@ def test_wg_too_small(ctx_factory):
 
     for gen_knl in kernel_gen:
         try:
-            compiled = lp.CompiledKernel(ctx, gen_knl)
+            lp.CompiledKernel(ctx, gen_knl)
         except RuntimeError, e:
             assert "implemented and desired" in str(e)
             pass # expected!
@@ -57,6 +61,57 @@ def test_wg_too_small(ctx_factory):
 
 
 
+def test_multi_cse(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i]: 0<=i<100}",
+            [
+                "[i] <float32> z[i] = cse(a[i]) + cse(a[i])**2"
+                ],
+            [lp.ArrayArg("a", np.float32, shape=(100,))],
+            local_sizes={0: 16})
+
+    knl = lp.split_dimension(knl, "i", 16, inner_tag="l.0")
+    knl = lp.realize_cse(knl, None, np.float32, ["i_inner"])
+
+    kernel_gen = lp.generate_loop_schedules(knl)
+    kernel_gen = lp.check_kernels(kernel_gen)
+
+    for gen_knl in kernel_gen:
+        compiled = lp.CompiledKernel(ctx, gen_knl)
+        print compiled.code
+
+
+
+
+def test_stencil(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i,j]: 0<= i,j <4}",
+            [
+                "[i] <float32> z[i,j] = -2*cse(a[i,j])"
+                    " + cse(a[i,j-1])"
+                    " + cse(a[i,j+1])"
+                    " + cse(a[i-1,j])"
+                    " + cse(a[i+1,i])" # watch out: i!
+                ],
+            [lp.ArrayArg("a", np.float32, shape=(4,4,))])
+
+    knl = lp.split_dimension(knl, "i", 16)
+
+    try:
+        lp.realize_cse(knl, None, np.float32, ["i_inner", "j"])
+    except RuntimeError, e:
+        assert "does not cover a subset" in str(e)
+        pass # expected!
+    else:
+        assert False # expecting an error
+
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab