From 27f410947cecf32ccae4328a516346fd3d556667 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 26 Nov 2014 21:35:04 -0600 Subject: [PATCH] Implement affine_map_inames --- doc/Makefile | 2 +- doc/reference.rst | 2 + loopy/__init__.py | 184 +++++++++++++++++++++++++++++++++++++++++++++ test/test_loopy.py | 12 +++ 4 files changed, 199 insertions(+), 1 deletion(-) diff --git a/doc/Makefile b/doc/Makefile index 66a7415df..e0cde0808 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line. SPHINXOPTS = -SPHINXBUILD = sphinx-build +SPHINXBUILD = python ` which sphinx-build` PAPER = BUILDDIR = _build diff --git a/doc/reference.rst b/doc/reference.rst index cffed51f7..c829462d3 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -368,6 +368,8 @@ Wrangling inames .. autofunction:: split_reduction_outward +.. autofunction:: affine_map_inames + Dealing with Parameters ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index b79eeeee6..bb397db32 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1480,4 +1480,188 @@ def set_argument_order(kernel, arg_names): # }}} +# {{{ affine map inames + +def affine_map_inames(kernel, old_inames, new_inames, equations): + """Return a new *kernel* where the affine transform + specified by *equations* has been applied to the inames. + + :arg old_inames: A list of inames to be replaced by affine transforms + of their values. + May also be a string of comma-separated inames. + + :arg new_inames: A list of new inames that are not yet used in *kernel*, + but have their values established in terms of *old_inames* by + *equations*. + May also be a string of comma-separated inames. + :arg equations: A list of equations estabilishing a relationship + between *old_inames* and *new_inames*. Each equation may be + a tuple ``(lhs, rhs)`` of expressions or a string, with left and + right hand side of the equation separated by ``=``. + """ + + # {{{ check and parse arguments + + if isinstance(new_inames, str): + new_inames = new_inames.split(",") + new_inames = [iname.strip() for iname in new_inames] + if isinstance(old_inames, str): + old_inames = old_inames.split(",") + old_inames = [iname.strip() for iname in old_inames] + if isinstance(equations, str): + equations = [equations] + + import re + EQN_RE = re.compile(r"^([^=]+)=([^=]+)$") + + def parse_equation(eqn): + if isinstance(eqn, str): + eqn_match = EQN_RE.match(eqn) + if not eqn_match: + raise ValueError("invalid equation: %s" % eqn) + + from loopy.symbolic import parse + lhs = parse(eqn_match.group(1)) + rhs = parse(eqn_match.group(2)) + return (lhs, rhs) + elif isinstance(eqn, tuple): + if len(eqn) != 2: + raise ValueError("unexpected length of equation tuple, " + "got %d, should be 2" % len(eqn)) + return eqn + else: + raise ValueError("unexpected type of equation" + "got %d, should be string or tuple" + % type(eqn).__name__) + + equations = [parse_equation(eqn) for eqn in equations] + + all_vars = kernel.all_variable_names() + for iname in new_inames: + if iname in all_vars: + raise LoopyError("new iname '%s' is already used in kernel" + % iname) + + for iname in old_inames: + if iname not in kernel.all_inames(): + raise LoopyError("old iname '%s' not known" % iname) + + # }}} + + # {{{ substitute iname use + + from pymbolic.algorithm import solve_affine_equations_for + old_inames_to_expr = solve_affine_equations_for(old_inames, equations) + + subst_dict = dict( + (v.name, expr) + for v, expr in old_inames_to_expr.items()) + + var_name_gen = kernel.get_var_name_generator() + + from pymbolic.mapper.substitutor import make_subst_func + old_to_new = ExpandingSubstitutionMapper(kernel.substitutions, var_name_gen, + make_subst_func(subst_dict), within=lambda stack: True) + + kernel = (old_to_new.map_kernel(kernel) + .copy( + applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict] + )) + + # }}} + + # {{{ change domains + + new_inames_set = set(new_inames) + old_inames_set = set(old_inames) + + new_domains = [] + for idom, dom in enumerate(kernel.domains): + dom_var_dict = dom.get_var_dict() + old_iname_overlap = [ + iname + for iname in old_inames + if iname in dom_var_dict] + + if not old_iname_overlap: + new_domains.append(dom) + continue + + from loopy.symbolic import get_dependencies + dom_new_inames = set() + dom_old_inames = set() + + # mapping for new inames to dim_types + new_iname_dim_types = {} + + dom_equations = [] + for iname in old_iname_overlap: + for ieqn, (lhs, rhs) in enumerate(equations): + eqn_deps = get_dependencies(lhs) | get_dependencies(rhs) + if iname in eqn_deps: + dom_new_inames.update(eqn_deps & new_inames_set) + dom_old_inames.update(eqn_deps & old_inames_set) + + if dom_old_inames: + dom_equations.append((lhs, rhs)) + + this_eqn_old_iname_dim_types = set( + dom_var_dict[old_iname][0] + for old_iname in eqn_deps & old_inames_set) + + if this_eqn_old_iname_dim_types: + if len(this_eqn_old_iname_dim_types) > 1: + raise ValueError("inames '%s' (from equation %d (0-based)) " + "in domain %d (0-based) are not " + "of a uniform dim_type" + % (", ".join(eqn_deps & old_inames_set), ieqn, idom)) + + this_eqn_new_iname_dim_type, = this_eqn_old_iname_dim_types + + for new_iname in eqn_deps & new_inames_set: + if new_iname in new_iname_dim_types: + if (this_eqn_new_iname_dim_type + != new_iname_dim_types[new_iname]): + raise ValueError("dim_type disagreement for " + "iname '%s' (from equation %d (0-based)) " + "in domain %d (0-based)" + % (new_iname, ieqn, idom)) + else: + new_iname_dim_types[new_iname] = \ + this_eqn_new_iname_dim_type + + if not dom_old_inames <= set(dom_var_dict): + raise ValueError("domain %d (0-based) does not know about " + "all old inames (specifically '%s') needed to define new inames" + % (idom, ", ".join(dom_old_inames - set(dom_var_dict)))) + + # add inames to domain with correct dim_types + dom_new_inames = list(dom_new_inames) + for iname in dom_new_inames: + dt = new_iname_dim_types[iname] + iname_idx = dom.dim(dt) + dom = dom.add_dims(dt, 1) + dom = dom.set_dim_name(dt, iname_idx, iname) + + # add equations + from loopy.symbolic import aff_from_expr + for lhs, rhs in dom_equations: + dom = dom.add_constraint( + isl.Constraint.equality_from_aff( + aff_from_expr(dom.space, rhs - lhs))) + + # project out old inames + for iname in dom_old_inames: + dt, idx = dom.get_var_dict()[iname] + dom = dom.project_out(dt, idx, 1) + + new_domains.append(dom) + + # }}} + + return kernel.copy(domains=new_domains) + +# }}} + + # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index 6dcaa2dd7..5918991e8 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1751,6 +1751,18 @@ def test_set_arg_order(): knl = lp.set_argument_order(knl, "out,a,n,b") +def test_affine_map_inames(): + knl = lp.make_kernel( + "{[e, i,j,n]: 0<=e<E and 0<=i,j,n<N}", + "rhsQ[e, n+i, j] = rhsQ[e, n+i, j] - D[i, n]*x[i,j]") + + knl = lp.affine_map_inames(knl, + "i", "i0", + "i0 = n+i") + + print(knl) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab