From 27f410947cecf32ccae4328a516346fd3d556667 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 26 Nov 2014 21:35:04 -0600
Subject: [PATCH] Implement affine_map_inames

---
 doc/Makefile       |   2 +-
 doc/reference.rst  |   2 +
 loopy/__init__.py  | 184 +++++++++++++++++++++++++++++++++++++++++++++
 test/test_loopy.py |  12 +++
 4 files changed, 199 insertions(+), 1 deletion(-)

diff --git a/doc/Makefile b/doc/Makefile
index 66a7415df..e0cde0808 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -3,7 +3,7 @@
 
 # You can set these variables from the command line.
 SPHINXOPTS    =
-SPHINXBUILD   = sphinx-build
+SPHINXBUILD   = python ` which sphinx-build`
 PAPER         =
 BUILDDIR      = _build
 
diff --git a/doc/reference.rst b/doc/reference.rst
index cffed51f7..c829462d3 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -368,6 +368,8 @@ Wrangling inames
 
 .. autofunction:: split_reduction_outward
 
+.. autofunction:: affine_map_inames
+
 Dealing with Parameters
 ^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index b79eeeee6..bb397db32 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1480,4 +1480,188 @@ def set_argument_order(kernel, arg_names):
 # }}}
 
 
+# {{{ affine map inames
+
+def affine_map_inames(kernel, old_inames, new_inames, equations):
+    """Return a new *kernel* where the affine transform
+    specified by *equations* has been applied to the inames.
+
+    :arg old_inames: A list of inames to be replaced by affine transforms
+        of their values.
+        May also be a string of comma-separated inames.
+
+    :arg new_inames: A list of new inames that are not yet used in *kernel*,
+        but have their values established in terms of *old_inames* by
+        *equations*.
+        May also be a string of comma-separated inames.
+    :arg equations: A list of equations estabilishing a relationship
+        between *old_inames* and *new_inames*. Each equation may be
+        a tuple ``(lhs, rhs)`` of expressions or a string, with left and
+        right hand side of the equation separated by ``=``.
+    """
+
+    # {{{ check and parse arguments
+
+    if isinstance(new_inames, str):
+        new_inames = new_inames.split(",")
+        new_inames = [iname.strip() for iname in new_inames]
+    if isinstance(old_inames, str):
+        old_inames = old_inames.split(",")
+        old_inames = [iname.strip() for iname in old_inames]
+    if isinstance(equations, str):
+        equations = [equations]
+
+    import re
+    EQN_RE = re.compile(r"^([^=]+)=([^=]+)$")
+
+    def parse_equation(eqn):
+        if isinstance(eqn, str):
+            eqn_match = EQN_RE.match(eqn)
+            if not eqn_match:
+                raise ValueError("invalid equation: %s" % eqn)
+
+            from loopy.symbolic import parse
+            lhs = parse(eqn_match.group(1))
+            rhs = parse(eqn_match.group(2))
+            return (lhs, rhs)
+        elif isinstance(eqn, tuple):
+            if len(eqn) != 2:
+                raise ValueError("unexpected length of equation tuple, "
+                        "got %d, should be 2" % len(eqn))
+            return eqn
+        else:
+            raise ValueError("unexpected type of equation"
+                    "got %d, should be string or tuple"
+                    % type(eqn).__name__)
+
+    equations = [parse_equation(eqn) for eqn in equations]
+
+    all_vars = kernel.all_variable_names()
+    for iname in new_inames:
+        if iname in all_vars:
+            raise LoopyError("new iname '%s' is already used in kernel"
+                    % iname)
+
+    for iname in old_inames:
+        if iname not in kernel.all_inames():
+            raise LoopyError("old iname '%s' not known" % iname)
+
+    # }}}
+
+    # {{{ substitute iname use
+
+    from pymbolic.algorithm import solve_affine_equations_for
+    old_inames_to_expr = solve_affine_equations_for(old_inames, equations)
+
+    subst_dict = dict(
+            (v.name, expr)
+            for v, expr in old_inames_to_expr.items())
+
+    var_name_gen = kernel.get_var_name_generator()
+
+    from pymbolic.mapper.substitutor import make_subst_func
+    old_to_new = ExpandingSubstitutionMapper(kernel.substitutions, var_name_gen,
+            make_subst_func(subst_dict), within=lambda stack: True)
+
+    kernel = (old_to_new.map_kernel(kernel)
+            .copy(
+                applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict]
+                ))
+
+    # }}}
+
+    # {{{ change domains
+
+    new_inames_set = set(new_inames)
+    old_inames_set = set(old_inames)
+
+    new_domains = []
+    for idom, dom in enumerate(kernel.domains):
+        dom_var_dict = dom.get_var_dict()
+        old_iname_overlap = [
+                iname
+                for iname in old_inames
+                if iname in dom_var_dict]
+
+        if not old_iname_overlap:
+            new_domains.append(dom)
+            continue
+
+        from loopy.symbolic import get_dependencies
+        dom_new_inames = set()
+        dom_old_inames = set()
+
+        # mapping for new inames to dim_types
+        new_iname_dim_types = {}
+
+        dom_equations = []
+        for iname in old_iname_overlap:
+            for ieqn, (lhs, rhs) in enumerate(equations):
+                eqn_deps = get_dependencies(lhs) | get_dependencies(rhs)
+                if iname in eqn_deps:
+                    dom_new_inames.update(eqn_deps & new_inames_set)
+                    dom_old_inames.update(eqn_deps & old_inames_set)
+
+                if dom_old_inames:
+                    dom_equations.append((lhs, rhs))
+
+                this_eqn_old_iname_dim_types = set(
+                        dom_var_dict[old_iname][0]
+                        for old_iname in eqn_deps & old_inames_set)
+
+                if this_eqn_old_iname_dim_types:
+                    if len(this_eqn_old_iname_dim_types) > 1:
+                        raise ValueError("inames '%s' (from equation %d (0-based)) "
+                                "in domain %d (0-based) are not "
+                                "of a uniform dim_type"
+                                % (", ".join(eqn_deps & old_inames_set), ieqn, idom))
+
+                    this_eqn_new_iname_dim_type, = this_eqn_old_iname_dim_types
+
+                    for new_iname in eqn_deps & new_inames_set:
+                        if new_iname in new_iname_dim_types:
+                            if (this_eqn_new_iname_dim_type
+                                    != new_iname_dim_types[new_iname]):
+                                raise ValueError("dim_type disagreement for "
+                                        "iname '%s' (from equation %d (0-based)) "
+                                        "in domain %d (0-based)"
+                                        % (new_iname, ieqn, idom))
+                        else:
+                            new_iname_dim_types[new_iname] = \
+                                    this_eqn_new_iname_dim_type
+
+        if not dom_old_inames <= set(dom_var_dict):
+            raise ValueError("domain %d (0-based) does not know about "
+                    "all old inames (specifically '%s') needed to define new inames"
+                    % (idom, ", ".join(dom_old_inames - set(dom_var_dict))))
+
+        # add inames to domain with correct dim_types
+        dom_new_inames = list(dom_new_inames)
+        for iname in dom_new_inames:
+            dt = new_iname_dim_types[iname]
+            iname_idx = dom.dim(dt)
+            dom = dom.add_dims(dt, 1)
+            dom = dom.set_dim_name(dt, iname_idx, iname)
+
+        # add equations
+        from loopy.symbolic import aff_from_expr
+        for lhs, rhs in dom_equations:
+            dom = dom.add_constraint(
+                    isl.Constraint.equality_from_aff(
+                        aff_from_expr(dom.space, rhs - lhs)))
+
+        # project out old inames
+        for iname in dom_old_inames:
+            dt, idx = dom.get_var_dict()[iname]
+            dom = dom.project_out(dt, idx, 1)
+
+        new_domains.append(dom)
+
+    # }}}
+
+    return kernel.copy(domains=new_domains)
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 6dcaa2dd7..5918991e8 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1751,6 +1751,18 @@ def test_set_arg_order():
     knl = lp.set_argument_order(knl, "out,a,n,b")
 
 
+def test_affine_map_inames():
+    knl = lp.make_kernel(
+        "{[e, i,j,n]: 0<=e<E and 0<=i,j,n<N}",
+        "rhsQ[e, n+i, j] = rhsQ[e, n+i, j] - D[i, n]*x[i,j]")
+
+    knl = lp.affine_map_inames(knl,
+            "i", "i0",
+            "i0 = n+i")
+
+    print(knl)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab