From 6f48d80bb2e90aabde48a1b8e006a0220597f843 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 10 Aug 2011 02:48:26 +0200
Subject: [PATCH] Clean up loop bound handling. Don't remove div constraints
 wholesale.

---
 examples/matrix-mul.py    |   2 +-
 loopy/__init__.py         |   2 +
 loopy/codegen/__init__.py |  13 ++-
 loopy/codegen/bounds.py   | 197 +++++++++++++++++++++++++++++++++-----
 loopy/codegen/loop_dim.py |  49 ++++++----
 loopy/codegen/prefetch.py |  24 ++---
 loopy/compiled.py         |   6 +-
 loopy/isl.py              | 183 +----------------------------------
 loopy/kernel.py           |  31 +++---
 loopy/prefetch.py         |  50 ++++++++--
 loopy/symbolic.py         |  56 ++++++++++-
 test/test_matmul.py       |  10 +-
 12 files changed, 356 insertions(+), 267 deletions(-)

diff --git a/examples/matrix-mul.py b/examples/matrix-mul.py
index 178a9a291..82c8c4e33 100644
--- a/examples/matrix-mul.py
+++ b/examples/matrix-mul.py
@@ -56,7 +56,7 @@ def image_matrix_mul_ilp(ctx_factory=cl.create_some_context):
 
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["j_inner_outer", "j_inner_inner", "k_inner"])
-    assert knl.get_problems()[0] <= 2
+    assert knl.get_problems({})[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
             for knl in lp.generate_loop_schedules(knl))
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 417957584..2aadba044 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -34,6 +34,8 @@ register_mpz_with_pymbolic()
 
 # TODO: implement efficient div_ceil?
 # TODO: why are corner cases inefficient?
+# TODO: Use gists
+# TODO: Imitate codegen bulk slab handling in bulk slab trials
 
 
 
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 4048ca971..d0e4cf4a3 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -240,13 +240,14 @@ def generate_code(kernel):
             (TAG_GROUP_IDX, "get_group_id"),
             (TAG_WORK_ITEM_IDX, "get_local_id")]:
         for iname in kernel.ordered_inames_by_tag_type(what_cls):
-            start, stop = kernel.get_projected_bounds(iname)
+            lower, upper, equality = kernel.get_bounds(iname, (iname,), allow_parameters=True)
+            assert not equality
             mod.append(Define(iname, "(%s + (int) %s(%d)) /* [%s, %s) */"
-                        % (ccm(start),
+                        % (ccm(lower),
                             func,
                             kernel.iname_to_tag[iname].axis,
-                            ccm(start),
-                            ccm(stop))))
+                            ccm(lower),
+                            ccm(upper))))
 
     mod.append(Line())
 
@@ -277,7 +278,9 @@ def generate_code(kernel):
         FunctionBody(
             CLRequiredWorkGroupSize(
                 tuple(dim_length
-                    for dim_length in kernel.tag_type_lengths(TAG_WORK_ITEM_IDX)),
+                    for dim_length in kernel.tag_type_lengths(
+                        TAG_WORK_ITEM_IDX,
+                        allow_parameters=False)),
                 CLKernel(FunctionDeclaration(
                     Value("void", kernel.name), args))),
             body))
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index cacea9f15..e602e951e 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -2,10 +2,129 @@ from __future__ import division
 
 import islpy as isl
 from islpy import dim_type
+from pymbolic.mapper.stringifier import PREC_NONE
 
 
 
 
+# {{{ find bounds from set
+
+def get_bounds_constraints(set, iname, admissible_inames, allow_parameters):
+    if admissible_inames is not None or not allow_parameters:
+        if admissible_inames is None:
+            proj_type = []
+        else:
+            assert iname in admissible_inames
+            proj_type = [dim_type.set]
+
+        if not allow_parameters:
+            proj_type.append(dim_type.param)
+
+        set = (set
+                .project_out_except(admissible_inames, proj_type)
+                .compute_divs()
+                .remove_divs_of_dim_type(dim_type.set))
+
+    basic_sets = set.get_basic_sets()
+    if len(basic_sets) > 1:
+        set = set.coalesce()
+        basic_sets = set.get_basic_sets()
+        if len(basic_sets) > 1:
+            raise RuntimeError("got non-convex set in bounds generation")
+
+    bset, = basic_sets
+
+    # FIXME perhaps use some form of hull here if there's more than one
+    # basic set?
+
+    lower = []
+    upper = []
+    equality = []
+
+    space = bset.get_dim()
+
+    var_dict = space.get_var_dict()
+    iname_tp, iname_idx = var_dict[iname]
+
+    for cns in bset.get_constraints():
+        assert not cns.is_div_constraint()
+
+        iname_coeff = int(cns.get_coefficient(iname_tp, iname_idx))
+
+        if iname_coeff == 0:
+            continue
+
+        if cns.is_equality():
+            equality.append(cns)
+        elif iname_coeff < 0:
+            upper.append(cns)
+        else: #  iname_coeff > 0
+            lower.append(cns)
+
+    return lower, upper, equality
+
+def solve_constraint_for_bound(cns, iname):
+    from loopy.symbolic import constraint_to_expr
+    rhs, iname_coeff = constraint_to_expr(cns, except_name=iname)
+
+    if iname_coeff == 0:
+        raise ValueError("cannot solve constraint for '%s'--"
+                "constraint does not contain variable"
+                % iname)
+
+    from pymbolic import expand
+    from pytools import div_ceil
+    from pymbolic import flatten
+    from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
+    cfm = CommutativeConstantFoldingMapper()
+
+    if iname_coeff > 0 or cns.is_equality():
+        if cns.is_equality():
+            kind = "=="
+        else:
+            kind = ">="
+
+        return kind, cfm(flatten(div_ceil(expand(-rhs), iname_coeff)))
+    else: # iname_coeff < 0
+        from pytools import div_ceil
+        return "<", cfm(flatten(div_ceil(rhs+1, -iname_coeff)))
+
+def get_bounds(set, iname, admissible_inames, allow_parameters):
+    """Get an overapproximation of the loop bounds for the variable *iname*,
+    as expressions *(lower, upper, equalities)*; *upper* is exclusive.
+    """
+
+    lower, upper, equality = get_bounds_constraints(
+            set, iname, admissible_inames, allow_parameters)
+
+    def do_solve(cns_list, assert_kind):
+        result = []
+        for cns in cns_list:
+            kind, bound = solve_constraint_for_bound(cns, iname)
+            assert kind == assert_kind
+            result.append(bound)
+
+        return result
+
+    lower_bounds = do_solve(lower, ">=")
+    upper_bounds = do_solve(upper, "<")
+    equalities = do_solve(equality, "==")
+
+    def agg_if_more_than_one(descr, agg_func, l):
+        if len(l) == 0:
+            raise ValueError("no %s bound found for '%s'" % (descr, iname))
+        elif len(l) == 1:
+            return l[0]
+        else:
+            return agg_func(l)
+
+    from pymbolic.primitives import Min, Max
+    return (agg_if_more_than_one("lower", Max, lower_bounds),
+            agg_if_more_than_one("upper", Min, upper_bounds),
+            equalities)
+
+# }}}
+
 # {{{ bounds check generator
 
 def constraint_to_code(ccm, cns):
@@ -14,7 +133,7 @@ def constraint_to_code(ccm, cns):
     else:
         comp_op = ">="
 
-    from loopy.isl import constraint_to_expr
+    from loopy.symbolic import constraint_to_expr
     return "%s %s 0" % (ccm(constraint_to_expr(cns)), comp_op)
 
 def filter_necessary_constraints(implemented_domain, constraints):
@@ -26,11 +145,12 @@ def filter_necessary_constraints(implemented_domain, constraints):
             .add_constraint(cns))]
 
 def generate_bounds_checks(domain, check_vars, implemented_domain):
-    domain_bset, = domain.get_basic_sets()
-
-    projected_domain_bset = isl.project_out_except(
-            domain_bset, check_vars, [dim_type.set])
-    projected_domain_bset = projected_domain_bset.remove_divs()
+    projected_domain_bset, = (domain
+            .project_out_except(check_vars, [dim_type.set])
+            .compute_divs()
+            .remove_divs_of_dim_type(dim_type.set)
+            .coalesce()
+            .get_basic_sets())
 
     space = domain.get_dim()
 
@@ -71,7 +191,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
 
     cfm = CommutativeConstantFoldingMapper()
 
-    from loopy.isl import constraint_to_expr, solve_constraint_for_bound
+    from loopy.symbolic import constraint_to_expr
     from pytools import any
 
     if any(cns.is_equality() for cns in constraints):
@@ -95,25 +215,28 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
                 assert kind == ">="
                 start_exprs.append(bound)
 
-    while len(start_exprs) >= 2:
-        start_exprs.append(
-                "max(%s, %s)" % (
-                    ccm(start_exprs.pop()),
-                    ccm(start_exprs.pop())))
-
-    start_expr, = start_exprs # there has to be at least one
+    if len(start_exprs) > 1:
+        from pymbolic.primitives import Max
+        start_expr = Max(start_exprs)
+    elif len(start_exprs) == 1:
+        start_expr, = start_exprs
+    else:
+        raise RuntimeError("no starting value found for 'for' loop in '%s'"
+                % iname)
 
     from cgen import For
     from loopy.codegen import wrap_in
     return wrap_in(For,
-            "int %s = %s" % (iname, start_expr),
+            "int %s = %s" % (iname, ccm(start_expr, PREC_NONE)),
             " && ".join(end_conds),
             "++%s" % iname,
             stmt)
 
 # }}}
 
-def get_valid_check_vars(kernel, sched_index, allow_ilp, exclude_tag_classes=()):
+# {{{ on which variables may a conditional depend?
+
+def get_defined_vars(kernel, sched_index, allow_ilp, exclude_tag_classes=()):
     """
     :param exclude_tags: a tuple of tag classes to exclude
     """
@@ -123,11 +246,27 @@ def get_valid_check_vars(kernel, sched_index, allow_ilp, exclude_tag_classes=())
         exclude_tag_classes = exclude_tag_classes + (TAG_ILP,)
 
     from loopy.schedule import ScheduledLoop
-    allowed_vars = set(
+    defined_vars = set(
             sched_item.iname
             for sched_item in kernel.schedule[:sched_index]
             if isinstance(sched_item, ScheduledLoop))
 
+    defined_vars = set(
+            iname
+            for iname in defined_vars
+            if not isinstance(
+                kernel.iname_to_tag.get(iname),
+                exclude_tag_classes))
+
+    return defined_vars
+
+def get_valid_check_vars(kernel, sched_index, allow_ilp, exclude_tag_classes=()):
+    """
+    :param exclude_tag_classes: a tuple of tag classes to exclude
+    """
+
+    allowed_vars = get_defined_vars(kernel, sched_index, allow_ilp, exclude_tag_classes)
+
     from pytools import any
     from loopy.prefetch import LocalMemoryPrefetch
     all_lmem_prefetches_scheduled = not any(
@@ -143,15 +282,25 @@ def get_valid_check_vars(kernel, sched_index, allow_ilp, exclude_tag_classes=())
         from loopy.kernel import TAG_WORK_ITEM_IDX
         allowed_vars -= set(kernel.inames_by_tag_type(TAG_WORK_ITEM_IDX))
 
-    allowed_vars = set(
-            iname
-            for iname in allowed_vars
-            if not isinstance(
-                kernel.iname_to_tag.get(iname),
-                exclude_tag_classes))
-
     return allowed_vars
 
+# }}}
+
+# {{{ pick the cheapest (fewest-flops) bound constraint
+
+def pick_simple_constraint(constraints, iname):
+    if len(constraints) == 0:
+        raise RuntimeError("no constraint for '%s'" % iname)
+    elif len(constraints) == 1:
+        return constraints[0]
+
+    from pymbolic.mapper.flop_counter import FlopCounter
+    count_flops = FlopCounter()
+
+    from pytools import argmin2
+    return argmin2(
+            (cns, count_flops(solve_constraint_for_bound(cns, iname)[1]))
+            for cns in constraints)
 
 
 
diff --git a/loopy/codegen/loop_dim.py b/loopy/codegen/loop_dim.py
index 503cc6831..872aeca77 100644
--- a/loopy/codegen/loop_dim.py
+++ b/loopy/codegen/loop_dim.py
@@ -4,17 +4,36 @@ import numpy as np
 from loopy.codegen import ExecutionDomain, gen_code_block
 from pytools import Record
 import islpy as isl
+from islpy import dim_type
 from loopy.codegen.dispatch import build_loop_nest
 
 
 
 
 
+def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
+    from loopy.isl import cast_constraint_to_space
+    from loopy.codegen.bounds import get_bounds_constraints, get_defined_vars
+    lower_constraints_orig, upper_constraints_orig, equality_constraints_orig = \
+            get_bounds_constraints(kernel.domain, iname,
+                    frozenset([iname])
+                    | frozenset(get_defined_vars(kernel, sched_index+1, allow_ilp=False)),
+                    allow_parameters=True)
+
+    assert not equality_constraints_orig
+    from loopy.codegen.bounds import pick_simple_constraint
+    lb_cns_orig = pick_simple_constraint(lower_constraints_orig, iname)
+    ub_cns_orig = pick_simple_constraint(upper_constraints_orig, iname)
+
+    lb_cns_orig = cast_constraint_to_space(lb_cns_orig, kernel.space)
+    ub_cns_orig = cast_constraint_to_space(ub_cns_orig, kernel.space)
+
+    return lb_cns_orig, ub_cns_orig
+
 # {{{ conditional-minimizing slab decomposition
 
 def get_slab_decomposition(cgs, kernel, sched_index, exec_domain):
-    from loopy.isl import (cast_constraint_to_space,
-            block_shift_constraint, negate_constraint)
+    from loopy.isl import block_shift_constraint, negate_constraint
 
     ccm = cgs.c_code_mapper
     space = kernel.space
@@ -23,9 +42,8 @@ def get_slab_decomposition(cgs, kernel, sched_index, exec_domain):
 
     # {{{ attempt slab partition to reduce conditional count
 
-    lb_cns_orig, ub_cns_orig = kernel.get_projected_bounds_constraints(iname)
-    lb_cns_orig = cast_constraint_to_space(lb_cns_orig, space)
-    ub_cns_orig = cast_constraint_to_space(ub_cns_orig, space)
+    lb_cns_orig, ub_cns_orig = get_simple_loop_bounds(kernel, sched_index, iname,
+            exec_domain.implemented_domain)
 
     # jostle the constant in {lb,ub}_cns to see if we can get
     # fewer conditionals in the bulk middle segment
@@ -104,9 +122,8 @@ def get_slab_decomposition(cgs, kernel, sched_index, exec_domain):
 # {{{ unrolled/ILP loops
 
 def generate_unroll_or_ilp_code(cgs, kernel, sched_index, exec_domain):
-    from loopy.isl import (
-            cast_constraint_to_space, solve_constraint_for_bound,
-            block_shift_constraint)
+    from loopy.isl import block_shift_constraint
+    from loopy.codegen.bounds import solve_constraint_for_bound
 
     from cgen import (POD, Assign, Line, Statement as S, Initializer, Const)
 
@@ -115,9 +132,8 @@ def generate_unroll_or_ilp_code(cgs, kernel, sched_index, exec_domain):
     iname = kernel.schedule[sched_index].iname
     tag = kernel.iname_to_tag.get(iname)
 
-    lower_cns, upper_cns = kernel.get_projected_bounds_constraints(iname)
-    lower_cns = cast_constraint_to_space(lower_cns, space)
-    upper_cns = cast_constraint_to_space(upper_cns, space)
+    lower_cns, upper_cns = get_simple_loop_bounds(kernel, sched_index, iname,
+            exec_domain.implemented_domain)
 
     lower_kind, lower_bound = solve_constraint_for_bound(lower_cns, iname)
     upper_kind, upper_bound = solve_constraint_for_bound(upper_cns, iname)
@@ -125,10 +141,8 @@ def generate_unroll_or_ilp_code(cgs, kernel, sched_index, exec_domain):
     assert lower_kind == ">="
     assert upper_kind == "<"
 
-    from pymbolic import flatten
-    from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
-    cfm = CommutativeConstantFoldingMapper()
-    length = int(cfm(flatten(upper_bound-lower_bound)))
+    success, length = kernel.domain.project_out_except([iname], [dim_type.set]).count()
+    assert success == 0
 
     def generate_idx_eq_slabs():
         for i in xrange(length):
@@ -184,7 +198,6 @@ def generate_unroll_or_ilp_code(cgs, kernel, sched_index, exec_domain):
 def generate_parallel_loop_dim_code(cgs, kernel, sched_index, exec_domain):
     from loopy.isl import make_slab
 
-
     ccm = cgs.c_code_mapper
     space = kernel.space
     iname = kernel.schedule[sched_index].iname
@@ -206,7 +219,7 @@ def generate_parallel_loop_dim_code(cgs, kernel, sched_index, exec_domain):
                 .add_constraint(ub_cns_orig))
     else:
         impl_len = tag.forced_length
-        start, _ = kernel.get_projected_bounds(iname)
+        start, _, _ = kernel.get_bounds(iname, (iname,), allow_parameters=True)
         exec_domain = exec_domain.intersect(
                 make_slab(kernel.space, iname, start, start+impl_len))
 
@@ -224,7 +237,7 @@ def generate_parallel_loop_dim_code(cgs, kernel, sched_index, exec_domain):
                 domain=kernel.domain.intersect(slab))
         result.append(
                 add_comment(cmt,
-                    build_loop_nest(cgs, kernel, sched_index+1,
+                    build_loop_nest(cgs, new_kernel, sched_index+1,
                         exec_domain)))
 
     from loopy.codegen import gen_code_block
diff --git a/loopy/codegen/prefetch.py b/loopy/codegen/prefetch.py
index d2b2dcdca..946f23cc9 100644
--- a/loopy/codegen/prefetch.py
+++ b/loopy/codegen/prefetch.py
@@ -125,21 +125,20 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
     pf_iname = pf.inames[pf_iname_idx]
     realiz_inames = flnd.realization_inames[pf_iname_idx]
 
-    start_index, stop_index = flnd.kernel.get_projected_bounds(pf_iname)
-    try:
-        start_index = int(start_index)
-        stop_index = int(stop_index)
-    except TypeError:
-        raise RuntimeError("loop bounds for prefetch must be "
-                "known statically at code gen time")
+    start_index, stop_index = pf.dim_bounds_by_iname[pf_iname]
 
     dim_length = stop_index-start_index
 
     if realiz_inames is not None:
         # {{{ parallel fetch
 
-        realiz_bounds = [flnd.kernel.get_projected_bounds(rn) for rn in realiz_inames]
-        realiz_lengths = [stop-start for start, stop in realiz_bounds]
+        realiz_bounds = [
+                flnd.kernel.get_bounds(rn, (rn,), allow_parameters=False)
+                for rn in realiz_inames]
+        for realiz_start, realiz_stop, realiz_equality in realiz_bounds:
+            assert not realiz_equality
+
+        realiz_lengths = [stop-start for start, stop, equality in realiz_bounds]
         from pytools import product
         total_realiz_size = product(realiz_lengths)
 
@@ -189,12 +188,13 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
         pf_dim_var = "prefetch_dim_idx_%d" % pf_iname_idx
         pf_dim_expr = var(pf_dim_var)
 
-        lb_cns, ub_cns = flnd.kernel.get_projected_bounds_constraints(pf_iname)
+        lb_cns, ub_cns = pf.get_dim_bounds_constraints_by_iname(pf_iname)
+
         import islpy as isl
         from loopy.isl import cast_constraint_to_space
         loop_slab = (isl.Set.universe(flnd.kernel.space)
-                .add_constraint(cast_constraint_to_space(lb_cns, kernel.space))
-                .add_constraint(cast_constraint_to_space(ub_cns, kernel.space)))
+                .add_constraints([cast_constraint_to_space(cns, kernel.space)
+                    for cns in [lb_cns, ub_cns]]))
         new_impl_domain = implemented_domain.intersect(loop_slab)
 
         pf_idx_subst_map = pf_idx_subst_map.copy()
diff --git a/loopy/compiled.py b/loopy/compiled.py
index f8660daf5..bd72ecdaf 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -49,8 +49,10 @@ class CompiledKernel:
             self.size_args = size_args
 
         from loopy.kernel import TAG_GROUP_IDX, TAG_WORK_ITEM_IDX
-        gsize_expr = tuple(self.kernel.tag_type_lengths(TAG_GROUP_IDX))
-        lsize_expr = tuple(self.kernel.tag_type_lengths(TAG_WORK_ITEM_IDX))
+        gsize_expr = tuple(self.kernel.tag_type_lengths(
+            TAG_GROUP_IDX, allow_parameters=True))
+        lsize_expr = tuple(self.kernel.tag_type_lengths(
+            TAG_WORK_ITEM_IDX, allow_parameters=False))
 
         if not gsize_expr: gsize_expr = (1,)
         if not lsize_expr: lsize_expr = (1,)
diff --git a/loopy/isl.py b/loopy/isl.py
index c232e71c6..e31575e2d 100644
--- a/loopy/isl.py
+++ b/loopy/isl.py
@@ -6,163 +6,9 @@ from islpy import dim_type
 
 
 
-# {{{ expression -> constraint conversion
-
-def _constraint_from_expr(space, expr, constraint_factory):
-    from loopy.symbolic import CoefficientCollector
-    return constraint_factory(space,
-            CoefficientCollector()(expr))
-
-def eq_constraint_from_expr(space, expr):
-    return _constraint_from_expr(
-            space, expr, isl.Constraint.eq_from_names)
-
-def ineq_constraint_from_expr(space, expr):
-    return _constraint_from_expr(
-            space, expr, isl.Constraint.ineq_from_names)
-
-# }}}
-
 
 # {{{ isl helpers
 
-def get_bounds_constraints(bset, iname, space=None, admissible_vars=None):
-    if isinstance(bset, isl.Set):
-        bset, = bset.get_basic_sets()
-
-    constraints = bset.get_constraints()
-
-    if not isinstance(admissible_vars, set):
-        admissible_vars = set(admissible_vars)
-
-    lower = []
-    upper = []
-    equality = []
-
-    if space is None:
-        space = bset.get_dim()
-
-    var_dict = space.get_var_dict()
-    iname_tp, iname_idx = var_dict[iname]
-
-    for cns in constraints:
-        iname_coeff = int(cns.get_coefficient(iname_tp, iname_idx))
-
-        if admissible_vars is not None:
-            if not (set(cns.get_coefficients_by_name().iterkeys())
-                    <= admissible_vars):
-                continue
-
-        if iname_coeff == 0:
-            continue
-
-        if cns.is_equality():
-            equality.append(cns)
-        elif iname_coeff < 0:
-            upper.append(cns)
-        else: #  iname_coeff > 0
-            lower.append(cns)
-
-    return lower, upper, equality
-
-
-def get_projected_bounds_constraints(set, iname):
-    """Get an overapproximation of the loop bounds for the variable *iname*,
-    as constraints.
-    """
-
-    # project out every variable except iname
-    projected_domain = isl.project_out_except(set, [iname], [dim_type.set])
-
-    basic_sets = projected_domain.get_basic_sets()
-
-    # FIXME perhaps use some form of hull here if there's more than one
-    # basic set?
-    bset, = basic_sets
-
-    # Python-style, half-open bounds
-    upper_bounds = []
-    lower_bounds = []
-    bset = bset.remove_divs()
-
-    bset_iname_dim_type, bset_iname_idx = bset.get_dim().get_var_dict()[iname]
-
-    def examine_constraint(cns):
-        assert not cns.is_equality()
-        assert not cns.is_div_constraint()
-
-        coeffs = cns.get_coefficients_by_name()
-
-        iname_coeff = int(coeffs.get(iname, 0))
-        if iname_coeff == 0:
-            return
-        elif iname_coeff < 0:
-            upper_bounds.append(cns)
-        else: # iname_coeff > 0:
-            lower_bounds.append(cns)
-
-    bset.foreach_constraint(examine_constraint)
-
-    lb, = lower_bounds
-    ub, = upper_bounds
-
-    return lb, ub
-
-
-
-
-def solve_constraint_for_bound(cns, iname):
-    rhs, iname_coeff = constraint_to_expr(cns, except_name=iname)
-
-    if iname_coeff == 0:
-        raise ValueError("cannot solve constraint for '%s'--"
-                "constraint does not contain variable"
-                % iname)
-
-    from pymbolic import expand
-    from pytools import div_ceil
-    from pymbolic import flatten
-    from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
-    cfm = CommutativeConstantFoldingMapper()
-
-    if iname_coeff > 0 or cns.is_equality():
-        if cns.is_equality():
-            kind = "=="
-        else:
-            kind = ">="
-
-        return kind, cfm(flatten(div_ceil(expand(-rhs), iname_coeff)))
-    else: # iname_coeff < 0
-        from pytools import div_ceil
-        return "<", cfm(flatten(div_ceil(rhs+1, -iname_coeff)))
-
-
-
-
-def get_projected_bounds(set, iname):
-    """Get an overapproximation of the loop bounds for the variable *iname*,
-    as actual bounds.
-    """
-
-    lb_cns, ub_cns = get_projected_bounds_constraints(set, iname)
-
-    for cns in [lb_cns, ub_cns]:
-        iname_tp, iname_idx = lb_cns.get_dim().get_var_dict()[iname]
-        iname_coeff = cns.get_coefficient(iname_tp, iname_idx)
-
-        if iname_coeff == 0:
-            continue
-
-        kind, bound = solve_constraint_for_bound(cns, iname)
-        if kind == "<":
-            ub = bound
-        elif kind == ">=":
-            lb = bound
-        else:
-            raise ValueError("unsupported constraint kind")
-
-    return lb, ub
-
 def cast_constraint_to_space(cns, new_space, as_equality=None):
     if as_equality is None:
         as_equality = cns.is_equality()
@@ -197,15 +43,9 @@ def copy_constraint(cns, as_equality=None):
     return cast_constraint_to_space(cns, cns.get_dim(),
             as_equality=as_equality)
 
-def get_dim_bounds(set, inames):
-    vars = set.get_dim().get_var_dict(dim_type.set).keys()
-    return [get_projected_bounds(set, v) for v in inames]
-
-def count_box_from_bounds(bounds):
-    from pytools import product
-    return product(stop-start for start, stop in bounds)
-
 def make_index_map(set, index_expr):
+    from loopy.symbolic import eq_constraint_from_expr
+
     if not isinstance(index_expr, tuple):
         index_expr = (index_expr,)
 
@@ -227,6 +67,7 @@ def make_index_map(set, index_expr):
     return amap
 
 def make_slab(space, iname, start, stop):
+    from loopy.symbolic import ineq_constraint_from_expr
     from pymbolic import var
     var_iname = var(iname)
     return (isl.Set.universe(space)
@@ -237,24 +78,6 @@ def make_slab(space, iname, start, stop):
             .add_constraint(ineq_constraint_from_expr(
                 space, stop-1 - var_iname)))
 
-def constraint_to_expr(cns, except_name=None):
-    excepted_coeff = 0
-    result = 0
-    from pymbolic import var
-    for var_name, coeff in cns.get_coefficients_by_name().iteritems():
-        if isinstance(var_name, str):
-            if var_name == except_name:
-                excepted_coeff = int(coeff)
-            else:
-                result += int(coeff)*var(var_name)
-        else:
-            result += int(coeff)
-
-    if except_name is not None:
-        return result, excepted_coeff
-    else:
-        return result
-
 # }}}
 
 # vim: foldmethod=marker
diff --git a/loopy/kernel.py b/loopy/kernel.py
index b8f9d9753..e5be621b0 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -278,27 +278,29 @@ class LoopKernel(Record):
                 result.append(dim)
 
     @memoize_method
-    def get_projected_bounds_constraints(self, iname):
+    def get_bounds_constraints(self, iname, admissible_vars, allow_parameters):
         """Get an overapproximation of the loop bounds for the variable *iname*."""
 
-        from loopy.isl import get_projected_bounds_constraints
-        return get_projected_bounds_constraints(self.domain, iname)
+        from loopy.codegen.bounds import get_bounds_constraints
+        return get_bounds_constraints(self.domain, iname, admissible_vars,
+                allow_parameters)
 
     @memoize_method
-    def get_projected_bounds(self, iname):
+    def get_bounds(self, iname, admissible_vars, allow_parameters):
         """Get an overapproximation of the loop bounds for the variable *iname*."""
 
-        from loopy.isl import get_projected_bounds
-        return get_projected_bounds(self.domain, iname)
+        from loopy.codegen.bounds import get_bounds
+        return get_bounds(self.domain, iname, admissible_vars, allow_parameters)
 
-    def tag_type_lengths(self, tag_cls):
+    def tag_type_lengths(self, tag_cls, allow_parameters):
         def get_length(iname):
             tag = self.iname_to_tag[iname]
             if tag.forced_length is not None:
                 return tag.forced_length
 
-            start, stop = self.get_projected_bounds(iname)
-            return stop-start
+            lower, upper, equality = self.get_bounds(iname, (iname,), 
+                    allow_parameters=allow_parameters)
+            return upper-lower
 
         return [get_length(iname)
                 for iname in self.ordered_inames_by_tag_type(tag_cls)]
@@ -458,7 +460,7 @@ class LoopKernel(Record):
                 .substitute(name, new_loop_index)
                 .copy(domain=new_domain, iname_to_tag=new_iname_to_tag))
 
-    def get_problems(self, emit_warnings=True):
+    def get_problems(self, parameters, emit_warnings=True):
         """
         :return: *(max_severity, list of (severity, msg))*, where *severity* ranges from 1-5.
             '5' means 'will certainly not run'.
@@ -473,8 +475,13 @@ class LoopKernel(Record):
 
             msgs.append((severity, s))
 
-        glens = self.tag_type_lengths(TAG_GROUP_IDX)
-        llens = self.tag_type_lengths(TAG_WORK_ITEM_IDX)
+        glens = self.tag_type_lengths(TAG_GROUP_IDX, allow_parameters=True)
+        llens = self.tag_type_lengths(TAG_WORK_ITEM_IDX, allow_parameters=False)
+
+        from pymbolic import evaluate
+        glens = evaluate(glens, parameters)
+        llens = evaluate(llens, parameters)
+
         if (max(len(glens), len(llens))
                 > self.device.max_work_item_dimensions):
             msg(5, "too many work item dimensions")
diff --git a/loopy/prefetch.py b/loopy/prefetch.py
index 9e2bef57b..4f5b436c0 100644
--- a/loopy/prefetch.py
+++ b/loopy/prefetch.py
@@ -1,7 +1,6 @@
 from __future__ import division
 
 from pytools import Record, memoize_method
-import islpy as isl
 from islpy import dim_type
 
 
@@ -90,8 +89,10 @@ class LocalMemoryPrefetch(Record):
     @property
     @memoize_method
     def domain(self):
-        return (isl.project_out_except(self.kernel.domain, self.inames, [dim_type.set])
-                .remove_divs())
+        return (self.kernel.domain
+                .project_out_except(self.inames, [dim_type.set])
+                .compute_divs()
+                .remove_divs_of_dim_type(dim_type.set))
 
     @property
     @memoize_method
@@ -106,20 +107,55 @@ class LocalMemoryPrefetch(Record):
     def restricted_index_map(self):
         return self.index_map.intersect_domain(self.domain)
 
+    @memoize_method
+    def get_dim_bounds_constraints_by_iname(self, iname):
+        from loopy.codegen.bounds import get_bounds_constraints
+        lower, upper, equality = get_bounds_constraints(
+                self.domain, iname, (iname,),
+                allow_parameters=False)
+
+        assert not equality
+
+        lower, = lower
+        upper, = upper
+        return lower, upper
+
+    @property
+    @memoize_method
+    def dim_bounds_by_iname(self):
+        from loopy.codegen.bounds import solve_constraint_for_bound
+        result = {}
+        for iname in self.inames:
+            lower, upper = self.get_dim_bounds_constraints_by_iname(iname)
+
+            lower_kind, lower_bound = solve_constraint_for_bound(lower, iname)
+            upper_kind, upper_bound = solve_constraint_for_bound(upper, iname)
+
+            try:
+                lower_bound = int(lower_bound)
+                upper_bound = int(upper_bound)
+            except TypeError:
+                raise RuntimeError("loop bounds for prefetch must be known statically")
+
+            result[iname] = (lower_bound, upper_bound)
+
+        return result
+
     @property
     @memoize_method
     def dim_bounds(self):
-        from loopy.isl import get_dim_bounds
-        return get_dim_bounds(self.domain, self.inames)
+        dbbi = self.dim_bounds_by_iname
+        return [dbbi[iname] for iname in self.inames]
 
     @property
     def itemsize(self):
         return self.kernel.arg_dict[self.input_vector].dtype.itemsize
+
     @property
     @memoize_method
     def nbytes(self):
-        from loopy.isl import count_box_from_bounds
-        return self.itemsize * count_box_from_bounds(self.dim_bounds)
+        from pytools import product
+        return self.itemsize * product(upper-lower for lower, upper in self.dim_bounds)
 
     @memoize_method
     def free_variables(self):
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index d318316a7..3056d8d03 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -6,6 +6,7 @@ from pymbolic.mapper import CombineMapper, RecursiveMapper
 from pymbolic.mapper.c_code import CCodeMapper
 from pymbolic.mapper.stringifier import PREC_NONE
 import numpy as np
+import islpy as isl
 
 
 
@@ -145,7 +146,7 @@ class LoopyCCodeMapper(CCodeMapper):
                 from pymbolic.mapper.stringifier import PREC_SUM
                 return pf.name+"".join(
                         "[%s - %s]" % (iname, self.rec(
-                            self.kernel.get_projected_bounds(iname)[0],
+                            pf.dim_bounds_by_iname[iname][0],
                             PREC_SUM))
                         for iname in pf.inames)
 
@@ -199,6 +200,59 @@ class LoopyCCodeMapper(CCodeMapper):
                     % (self.rec(expr.numerator, PREC_NONE),
                         self.rec(expr.denominator, PREC_NONE)))
 
+    def map_min(self, expr, prec):
+        """Render *expr* as right-nested binary calls, e.g. ``min(a, min(b, c))``.
+        The call name is the lowercased class name of *expr*; *prec* is unused."""
+        what = type(expr).__name__.lower()
+        # Copy into a real list: a tuple of children (or a slice of one) has no pop().
+        children = list(expr.children)
+        result = self.rec(children.pop(), PREC_NONE)
+        while children:
+            result = "%s(%s, %s)" % (what,
+                        self.rec(children.pop(), PREC_NONE),
+                        result)
+        return result
+
+    map_max = map_min
+
+# }}}
+
+# {{{ expression <-> constraint conversion
+
+def _constraint_from_expr(space, expr, constraint_factory):
+    # Collect the coefficients of *expr* and pass them on to *constraint_factory*.
+    from loopy.symbolic import CoefficientCollector
+    return constraint_factory(space, CoefficientCollector()(expr))
+
+def eq_constraint_from_expr(space, expr):
+    """Build a constraint in *space* from *expr* via ``isl.Constraint.eq_from_names``."""
+    return _constraint_from_expr(space, expr, isl.Constraint.eq_from_names)
+
+def ineq_constraint_from_expr(space, expr):
+    """Build a constraint in *space* from *expr* via ``isl.Constraint.ineq_from_names``."""
+    return _constraint_from_expr(space, expr, isl.Constraint.ineq_from_names)
+
+def constraint_to_expr(cns, except_name=None):
+    """Sum the terms of *cns* into an expression. If *except_name* is given, leave
+    that variable out and return ``(expr, its_coefficient)`` (0 if absent)."""
+    from pymbolic import var
+    expr = 0
+    excepted_coeff = 0
+    for name, coeff in cns.get_coefficients_by_name().iteritems():
+        coeff = int(coeff)
+        if not isinstance(name, str):
+            # Non-string keys contribute their coefficient as a plain number.
+            expr += coeff
+        elif name == except_name:
+            excepted_coeff = coeff
+        else:
+            expr += coeff*var(name)
+
+    return expr if except_name is None else (expr, excepted_coeff)
+
 # }}}
 
+
+
+
 # vim: foldmethod=marker
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 2618745cd..5fe5faaa5 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -104,7 +104,7 @@ def test_plain_matrix_mul(ctx_factory):
     knl = lp.split_dimension(knl, "k", 4)
     knl = lp.add_prefetch(knl, 'a', ["k_inner", "i_inner"])
     knl = lp.add_prefetch(knl, 'b', ["j_inner", "k_inner", ])
-    assert knl.get_problems()[0] <= 2
+    assert knl.get_problems({})[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
             for knl in lp.generate_loop_schedules(knl))
@@ -158,7 +158,7 @@ def test_image_matrix_mul(ctx_factory):
     # conflict-free
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["j_inner", "k_inner"])
-    assert knl.get_problems()[0] <= 2
+    assert knl.get_problems({})[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
             for knl in lp.generate_loop_schedules(knl))
@@ -216,7 +216,7 @@ def test_image_matrix_mul_ilp(ctx_factory):
     # conflict-free
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["j_inner_outer", "j_inner_inner", "k_inner"])
-    assert knl.get_problems()[0] <= 2
+    assert knl.get_problems({})[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
             for knl in lp.generate_loop_schedules(knl))
@@ -273,7 +273,7 @@ def test_fancy_matrix_mul(ctx_factory):
     knl = lp.split_dimension(knl, "k", 16)
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["k_inner", "j_inner"])
-    assert knl.get_problems()[0] <= 2
+    assert knl.get_problems(dict(n=n))[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
             for knl in lp.generate_loop_schedules(knl))
@@ -357,7 +357,7 @@ def test_dg_matrix_mul(ctx_factory):
         knl = lp.add_prefetch(knl, 'fld%d' % ifld,
                 #["k_inner_outer", "k_inner_inner", "j"])
                 ["k_inner", "j"])
-    assert knl.get_problems()[0] <= 2
+    assert knl.get_problems({})[0] <= 2
 
     kernel_gen = list(lp.insert_register_prefetches(knl)
             for knl in lp.generate_loop_schedules(knl))[:1]
-- 
GitLab