From 988c7bfd686b4cc79280545272681683fb8979c5 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 3 Sep 2012 02:18:58 -0400
Subject: [PATCH] Fix conditional generation.

Unify two pre-existing code paths for conditional generation.
Compute bounds for 'for' by isl bounds finding.
Use overapproximation *except* when the insn builds its own
  final conditional.
Btw, the axpy nondeterministic failure was a PyOpenCL bug in Array.view.
  (fixed in pyopencl:9207aae)
---
 MEMO                         |   2 -
 loopy/check.py               |  10 +-
 loopy/codegen/__init__.py    |  12 +-
 loopy/codegen/bounds.py      | 249 +++--------------------------------
 loopy/codegen/control.py     |  10 +-
 loopy/codegen/instruction.py |  25 +++-
 loopy/codegen/loop.py        | 120 ++++++++++++-----
 loopy/symbolic.py            |   7 +
 8 files changed, 163 insertions(+), 272 deletions(-)

diff --git a/MEMO b/MEMO
index 504788e49..8a972d247 100644
--- a/MEMO
+++ b/MEMO
@@ -47,8 +47,6 @@ To-do
 
 - Test join_inames
 
-- Debug axpy nondet fail
-
 - Make tests run on GPUs
 
 Fixes:
diff --git a/loopy/check.py b/loopy/check.py
index 51b949aa9..baef5d37f 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -296,7 +296,7 @@ def run_automatic_checks(kernel):
 
 # {{{ sanity-check for implemented domains of each instruction
 
-def check_implemented_domains(kernel, implemented_domains):
+def check_implemented_domains(kernel, implemented_domains, code=None):
     from islpy import dim_type
 
     from islpy import align_spaces, align_two
@@ -358,6 +358,14 @@ def check_implemented_domains(kernel, implemented_domains):
                 lines.append(
                         "sample point %s: %s" % (kind, ", ".join(point_axes)))
 
+            if code is not None:
+                print 79*"-"
+                print "CODE:"
+                print 79*"-"
+                from loopy.compiled import get_highlighted_code
+                print get_highlighted_code(code)
+                print 79*"-"
+
             raise RuntimeError("sanity check failed--implemented and desired "
                     "domain for instruction '%s' do not match\n\n"
                     "implemented: %s\n\n"
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 69739ed69..e72d00568 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -313,9 +313,6 @@ def generate_code(kernel, with_annotation=False,
                     Value("void", kernel.name), args))),
             body))
 
-    from loopy.check import check_implemented_domains
-    assert check_implemented_domains(kernel, gen_code.implemented_domains)
-
     # {{{ handle preambles
 
     for arg in kernel.args:
@@ -342,7 +339,14 @@ def generate_code(kernel, with_annotation=False,
 
     # }}}
 
-    return str(Module(mod))
+    result = str(Module(mod))
+
+    from loopy.check import check_implemented_domains
+    assert check_implemented_domains(kernel, gen_code.implemented_domains,
+            result)
+
+    return result
+
 
 # }}}
 
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index 983a0f521..a14e871fe 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -6,96 +6,6 @@ from islpy import dim_type
 
 
 
-# {{{ find bounds from set
-
-def get_bounds_constraints(set, iname, admissible_inames, allow_parameters):
-    """May overapproximate."""
-    if admissible_inames is not None or not allow_parameters:
-        if admissible_inames is None:
-            elim_type = []
-        else:
-            assert iname in admissible_inames
-            elim_type = [dim_type.set]
-
-        if not allow_parameters:
-            elim_type.append(dim_type.param)
-
-        set = set.eliminate_except(admissible_inames, elim_type)
-        set = set.remove_divs()
-
-    basic_sets = set.get_basic_sets()
-    if len(basic_sets) > 1:
-        set = set.coalesce()
-        basic_sets = set.get_basic_sets()
-        if len(basic_sets) > 1:
-            raise RuntimeError("got non-convex set in bounds generation")
-
-    bset, = basic_sets
-
-    # FIXME perhaps use some form of hull here if there's more than one
-    # basic set?
-
-    lower = []
-    upper = []
-    equality = []
-
-    space = bset.get_space()
-
-    var_dict = space.get_var_dict()
-    iname_tp, iname_idx = var_dict[iname]
-
-
-    from pytools import any
-    if any(cns.is_div_constraint() for cns in bset.get_constraints()):
-        bset = bset.remove_divs()
-
-    for cns in bset.get_constraints():
-        assert not cns.is_div_constraint()
-
-        iname_coeff = int(cns.get_coefficient(iname_tp, iname_idx))
-
-        if iname_coeff == 0:
-            continue
-
-        if cns.is_equality():
-            equality.append(cns)
-        elif iname_coeff < 0:
-            upper.append(cns)
-        else: #  iname_coeff > 0
-            lower.append(cns)
-
-    return lower, upper, equality
-
-def solve_constraint_for_bound(cns, iname):
-    from loopy.symbolic import constraint_to_expr
-    rhs, iname_coeff = constraint_to_expr(cns, except_name=iname)
-
-    if iname_coeff == 0:
-        raise ValueError("cannot solve constraint for '%s'--"
-                "constraint does not contain variable"
-                % iname)
-
-    from pymbolic import expand
-    from pytools import div_ceil
-    from pymbolic import flatten
-    from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
-    cfm = CommutativeConstantFoldingMapper()
-
-    if iname_coeff > 0 or cns.is_equality():
-        if cns.is_equality():
-            kind = "=="
-        else:
-            kind = ">="
-
-        return kind, cfm(flatten(div_ceil(expand(-rhs), iname_coeff)))
-    else: # iname_coeff < 0
-        from pytools import div_ceil
-        return "<", cfm(flatten(div_ceil(rhs+1, -iname_coeff)))
-
-# }}}
-
-# {{{ bounds check generator
-
 def constraint_to_code(ccm, cns):
     if cns.is_equality():
         comp_op = "=="
@@ -105,114 +15,38 @@ def constraint_to_code(ccm, cns):
     from loopy.symbolic import constraint_to_expr
     return "%s %s 0" % (ccm(constraint_to_expr(cns), 'i'), comp_op)
 
-def generate_bounds_checks(domain, check_inames, implemented_domain):
-    """Will not overapproximate."""
+# {{{ bounds check generator
 
+def get_bounds_checks(domain, check_inames, implemented_domain,
+        overapproximate):
     if isinstance(domain, isl.BasicSet):
         domain = isl.Set.from_basic_set(domain)
     domain = domain.remove_redundancies()
-    domain = isl.align_spaces(domain, implemented_domain)
-    result = domain.gist(implemented_domain)
-
-    result = (result
-        .eliminate_except(check_inames, [dim_type.set])
-        .compute_divs())
-
-    from loopy.isl_helpers import convexify
-    return convexify(result).get_constraints()
-
-def wrap_in_bounds_checks(ccm, domain, check_inames, implemented_domain, stmt):
-    bounds_checks = generate_bounds_checks(
-            domain, check_inames,
-            implemented_domain)
+    result = domain.eliminate_except(check_inames, [dim_type.set])
 
-    bounds_check_set = isl.Set.universe(domain.get_space()).add_constraints(bounds_checks)
-    bounds_check_set, new_implemented_domain = isl.align_two(
-            bounds_check_set, implemented_domain)
-    new_implemented_domain = new_implemented_domain & bounds_check_set
-
-    condition_codelets = [
-            constraint_to_code(ccm, cns) for cns in
-            generate_bounds_checks(domain, check_inames, implemented_domain)]
-
-    if condition_codelets:
-        from cgen import If
-        stmt = If("\n&& ".join(condition_codelets), stmt)
-
-    return stmt, new_implemented_domain
-
-def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt,
-        index_dtype):
-    # FIXME add admissible vars
-    if isinstance(constraint_bset, isl.Set):
-        constraint_bset, = constraint_bset.get_basic_sets()
-
-    constraints = constraint_bset.get_constraints()
-
-    from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
-
-    cfm = CommutativeConstantFoldingMapper()
-
-    from loopy.symbolic import constraint_to_expr
-
-    start_exprs = []
-    end_conds = []
-    equality_exprs = []
-
-    for cns in constraints:
-        rhs, iname_coeff = constraint_to_expr(cns, except_name=iname)
-
-        if iname_coeff == 0:
-            continue
-
-        if cns.is_equality():
-            kind, bound = solve_constraint_for_bound(cns, iname)
-            assert kind == "=="
-            equality_exprs.append(bound)
-        elif iname_coeff < 0:
-            from pymbolic import var
-            rhs += iname_coeff*var(iname)
-            end_conds.append("%s >= 0" %
-                    ccm(cfm(rhs), 'i'))
-        else: #  iname_coeff > 0
-            kind, bound = solve_constraint_for_bound(cns, iname)
-            assert kind == ">="
-            start_exprs.append(bound)
-
-    if equality_exprs:
-        assert len(equality_exprs) == 1
+    if overapproximate:
+        # This is ok, because we're really looking for the
+        # projection, with no remaining constraints from
+        # the eliminated variables.
+        result = result.remove_divs()
+    else:
+        result = result.compute_divs()
 
-        equality_expr, = equality_exprs
+    result = isl.align_spaces(result, implemented_domain)
+    result = result.gist(implemented_domain)
 
-        from loopy.codegen import gen_code_block
-        from cgen import Initializer, POD, Const, Line
-        return gen_code_block([
-            Initializer(Const(POD(index_dtype, iname)),
-                ccm(equality_expr, 'i')),
-            Line(),
-            stmt,
-            ])
+    if overapproximate:
+        result = result.remove_divs()
     else:
-        if len(start_exprs) > 1:
-            from pymbolic.primitives import Max
-            start_expr = Max(start_exprs)
-        elif len(start_exprs) == 1:
-            start_expr, = start_exprs
-        else:
-            raise RuntimeError("no starting value found for 'for' loop in '%s'"
-                    % iname)
+        result = result.compute_divs()
 
-        from cgen import For
-        from loopy.codegen import wrap_in
-        return wrap_in(For,
-                "int %s = %s" % (iname, ccm(start_expr, 'i')),
-                " && ".join(end_conds),
-                "++%s" % iname,
-                stmt)
+    from loopy.isl_helpers import convexify
+    result = convexify(result).get_constraints()
+    return result
 
 # }}}
 
-# {{{ on which variables may a conditional depend?
+# {{{ on which inames may a conditional depend?
 
 def get_usable_inames_for_conditional(kernel, sched_index):
     from loopy.schedule import EnterLoop, LeaveLoop
@@ -242,49 +76,6 @@ def get_usable_inames_for_conditional(kernel, sched_index):
 
 # }}}
 
-# {{{ get_simple_loop_bounds
-
-def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain, iname_domain):
-    from loopy.codegen.bounds import (get_bounds_constraints,
-            get_usable_inames_for_conditional)
-    lower_constraints_orig, upper_constraints_orig, equality_constraints_orig = \
-            get_bounds_constraints(iname_domain, iname,
-                    frozenset([iname])
-                    | get_usable_inames_for_conditional(kernel, sched_index+1),
-                    allow_parameters=True)
-
-    lower_constraints_orig.extend(equality_constraints_orig)
-    upper_constraints_orig.extend(equality_constraints_orig)
-    #assert not equality_constraints_orig
-
-    from loopy.codegen.bounds import pick_simple_constraint
-    lb_cns_orig = pick_simple_constraint(lower_constraints_orig, iname)
-    ub_cns_orig = pick_simple_constraint(upper_constraints_orig, iname)
-
-    return lb_cns_orig, ub_cns_orig
-
-# }}}
-
-# {{{ pick_simple_constraint
-
-def pick_simple_constraint(constraints, iname):
-    if len(constraints) == 0:
-        raise RuntimeError("no constraint for '%s'" % iname)
-    elif len(constraints) == 1:
-        return constraints[0]
-
-    from pymbolic.mapper.flop_counter import FlopCounter
-    count_flops = FlopCounter()
-
-    from pytools import argmin2
-    return argmin2(
-            (cns, count_flops(solve_constraint_for_bound(cns, iname)[1]))
-            for cns in constraints)
-
-# }}}
-
-
-
 
 
 
diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py
index 5dfea65cc..db1c4f4a0 100644
--- a/loopy/codegen/control.py
+++ b/loopy/codegen/control.py
@@ -174,9 +174,13 @@ def build_loop_nest(kernel, sched_index, codegen_state):
             domain = isl.align_spaces(
                     self.kernel.get_inames_domain(check_inames),
                     self.impl_domain, obj_bigger_ok=True)
-            from loopy.codegen.bounds import generate_bounds_checks
-            return generate_bounds_checks(domain,
-                    check_inames, self.impl_domain)
+            from loopy.codegen.bounds import get_bounds_checks
+            return get_bounds_checks(domain,
+                    check_inames, self.impl_domain,
+
+                    # Each instruction individually gets its bounds checks,
+                    # so we can safely overapproximate here.
+                    overapproximate=True)
 
     def build_insn_group(sched_indices_and_cond_inames, codegen_state, done_group_lengths=set()):
         # done_group_lengths serves to prevent infinite recursion by imposing a
diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py
index 909f02fe2..617af6009 100644
--- a/loopy/codegen/instruction.py
+++ b/loopy/codegen/instruction.py
@@ -1,6 +1,30 @@
 """Code generation for Instruction objects."""
 from __future__ import division
 
+import islpy as isl
+
+
+
+
+def wrap_in_bounds_checks(ccm, domain, check_inames, implemented_domain, stmt):
+    from loopy.codegen.bounds import get_bounds_checks, constraint_to_code
+    bounds_checks = get_bounds_checks(
+            domain, check_inames,
+            implemented_domain, overapproximate=False)
+
+    bounds_check_set = isl.Set.universe(domain.get_space()).add_constraints(bounds_checks)
+    bounds_check_set, new_implemented_domain = isl.align_two(
+            bounds_check_set, implemented_domain)
+    new_implemented_domain = new_implemented_domain & bounds_check_set
+
+    condition_codelets = [constraint_to_code(ccm, cns) for cns in bounds_checks]
+
+    if condition_codelets:
+        from cgen import If
+        stmt = If("\n&& ".join(condition_codelets), stmt)
+
+    return stmt, new_implemented_domain
+
 
 
 
@@ -25,7 +49,6 @@ def generate_instruction_code(kernel, insn, codegen_state):
             ccm(insn.assignee, prec=None, type_context=None),
             ccm(expr, prec=None, type_context=dtype_to_type_context(target_dtype)))
 
-    from loopy.codegen.bounds import wrap_in_bounds_checks
     insn_inames = kernel.insn_inames(insn)
     insn_code, impl_domain = wrap_in_bounds_checks(
             ccm, kernel.get_inames_domain(insn_inames), insn_inames,
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index bbb9c1666..650ce34f4 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -2,6 +2,7 @@ from __future__ import division
 
 from loopy.codegen import gen_code_block
 import islpy as isl
+from islpy import dim_type
 from loopy.codegen.control import build_loop_nest
 
 
@@ -16,13 +17,11 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
     if iname_domain.is_empty():
         return ()
 
-    from loopy.codegen.bounds import get_simple_loop_bounds
-    lb_cns_orig, ub_cns_orig = get_simple_loop_bounds(kernel, sched_index, iname,
-            codegen_state.implemented_domain, iname_domain)
-
-    space = lb_cns_orig.space
+    space = iname_domain.space
 
     lower_incr, upper_incr = kernel.iname_slab_increments.get(iname, (0, 0))
+    lower_bulk_bound = None
+    upper_bulk_bound = None
 
     if lower_incr or upper_incr:
         bounds = kernel.get_iname_bounds(iname)
@@ -40,17 +39,11 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
         (_, lower_bound_aff), = lower_bound_pw_aff_pieces
         (_, upper_bound_aff), = upper_bound_pw_aff_pieces
 
-        lower_bulk_bound = lb_cns_orig
-        upper_bulk_bound = lb_cns_orig
-
         from loopy.isl_helpers import iname_rel_aff
 
-
         if lower_incr:
             assert lower_incr > 0
             lower_slab = ("initial", isl.BasicSet.universe(space)
-                    .add_constraint(lb_cns_orig)
-                    .add_constraint(ub_cns_orig)
                     .add_constraint(
                         isl.Constraint.inequality_from_aff(
                             iname_rel_aff(kernel.space,
@@ -65,8 +58,6 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
         if upper_incr:
             assert upper_incr > 0
             upper_slab = ("final", isl.BasicSet.universe(space)
-                    .add_constraint(lb_cns_orig)
-                    .add_constraint(ub_cns_orig)
                     .add_constraint(
                         isl.Constraint.inequality_from_aff(
                             iname_rel_aff(space,
@@ -82,21 +73,19 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
 
         if lower_slab:
             slabs.append(lower_slab)
-        slabs.append((
-            ("bulk",
-                (isl.BasicSet.universe(space)
-                    .add_constraint(lower_bulk_bound)
-                    .add_constraint(upper_bulk_bound)))))
+        bulk_slab = isl.BasicSet.universe(space)
+        if lower_bulk_bound is not None:
+            bulk_slab = bulk_slab.add_constraint(lower_bulk_bound)
+        if upper_bulk_bound is not None:
+            bulk_slab = bulk_slab.add_constraint(upper_bulk_bound)
+        slabs.append(("bulk", bulk_slab))
         if upper_slab:
             slabs.append(upper_slab)
 
         return slabs
 
     else:
-        return [("bulk",
-            (isl.BasicSet.universe(space)
-            .add_constraint(lb_cns_orig)
-            .add_constraint(ub_cns_orig)))]
+        return [("bulk", (isl.BasicSet.universe(space)))]
 
 # }}}
 
@@ -206,7 +195,7 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state, hw_inames_left=
             cmt = None
 
         # Have the conditional infrastructure generate the
-        # slabbin conditionals.
+        # slabbing conditionals.
         slabbed_kernel = intersect_kernel_with_slab(kernel, slab, iname)
 
         inner = set_up_hw_parallel_loops(
@@ -222,32 +211,99 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state, hw_inames_left=
 
 def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
     ccm = codegen_state.c_code_mapper
-    iname = kernel.schedule[sched_index].iname
+    loop_iname = kernel.schedule[sched_index].iname
 
     slabs = get_slab_decomposition(
-            kernel, iname, sched_index, codegen_state)
+            kernel, loop_iname, sched_index, codegen_state)
+
+    from loopy.codegen.bounds import get_usable_inames_for_conditional
+
+    # Note: this does not include loop_iname itself!
+    usable_inames = get_usable_inames_for_conditional(kernel, sched_index)
+    domain = kernel.get_inames_domain(loop_iname)
+
+    # move inames that are usable into parameters
+    for iname in domain.get_var_names(dim_type.set):
+        if iname in usable_inames:
+            dt, idx = domain.get_var_dict()[iname]
+            domain = domain.move_dims(
+                    dim_type.param, domain.dim(dim_type.param),
+                    dt, idx, 1)
+
 
     result = []
 
     for slab_name, slab in slabs:
-        cmt = "%s slab for '%s'" % (slab_name, iname)
+        cmt = "%s slab for '%s'" % (slab_name, loop_iname)
         if len(slabs) == 1:
             cmt = None
 
-        # Conditionals for slab are generated below.
-        new_codegen_state = codegen_state.intersect(slab)
+        # {{{ find bounds
+
+        domain = isl.align_spaces(domain, slab, across_dim_types=True,
+                obj_bigger_ok=True)
+        dom_and_slab = domain & slab
+        _, loop_iname_idx = domain.get_var_dict()[loop_iname]
+        lbound = kernel.cache_manager.dim_min(
+                dom_and_slab, loop_iname_idx).coalesce()
+        ubound = kernel.cache_manager.dim_max(
+                dom_and_slab, loop_iname_idx).coalesce()
+
+        from loopy.isl_helpers import (
+                static_min_of_pw_aff,
+                static_max_of_pw_aff)
+
+        lbound = static_min_of_pw_aff(lbound,
+                constants_only=False)
+        ubound = static_max_of_pw_aff(ubound,
+                constants_only=False)
+
+        # }}}
+
+        # {{{ find implemented slab, build inner code
+
+        from loopy.isl_helpers import iname_rel_aff
+        impl_slab = (
+                isl.BasicSet.universe(domain.space)
+                .add_constraint(
+                    isl.Constraint.inequality_from_aff(
+                        iname_rel_aff(domain.space,
+                            loop_iname, ">=", lbound)))
+                .add_constraint(
+                    isl.Constraint.inequality_from_aff(
+                        iname_rel_aff(domain.space,
+                            loop_iname, "<=", ubound))))
+
+        new_codegen_state = codegen_state.intersect(impl_slab)
 
         inner = build_loop_nest(kernel, sched_index+1,
                 new_codegen_state)
 
-        from loopy.codegen.bounds import wrap_in_for_from_constraints
+        # }}}
 
         if cmt is not None:
             from cgen import Comment
             result.append(Comment(cmt))
-        result.append(
-                wrap_in_for_from_constraints(ccm, iname, slab, inner,
-                    kernel.index_dtype))
+
+        from cgen import Initializer, POD, Const, Line, For
+        from loopy.symbolic import aff_to_expr
+
+        if (ubound - lbound).plain_is_zero():
+            # single-trip, generate just a variable assignment, not a loop
+            result.append(gen_code_block([
+                Initializer(Const(POD(kernel.index_dtype, loop_iname)),
+                    ccm(aff_to_expr(lbound), "i")),
+                Line(),
+                inner,
+                ]))
+
+        else:
+            from loopy.codegen import wrap_in
+            result.append(wrap_in(For,
+                    "int %s = %s" % (loop_iname, ccm(aff_to_expr(lbound), "i")),
+                    "%s <= %s" % (loop_iname, ccm(aff_to_expr(ubound), "i")),
+                    "++%s" % loop_iname,
+                    inner))
 
     return gen_code_block(result)
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 5d6a6c5f4..1a107cd9a 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -482,6 +482,13 @@ def ineq_constraint_from_expr(space, expr):
     return isl.Constraint.inequality_from_aff(aff_from_expr(space,expr))
 
 def constraint_to_expr(cns, except_name=None):
+    # Looks like this is ok after all--get_aff() performs some magic.
+    # Not entirely sure though... FIXME
+    #
+    #ls = cns.get_local_space()
+    #if ls.dim(dim_type.div):
+        #raise RuntimeError("constraint has an existentially quantified variable")
+
     return aff_to_expr(cns.get_aff(), except_name=except_name)
 
 # }}}
-- 
GitLab