From e86c0f6d27bd9f757a40adfd354c1e0c09cded2b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 19 Jul 2011 19:25:41 -0500
Subject: [PATCH] Add conditional avoidance.

---
 examples/matrix-ops.py |  24 ++--
 loopy/__init__.py      | 294 ++++++++++++++++++++++++++++++-----------
 2 files changed, 231 insertions(+), 87 deletions(-)

diff --git a/examples/matrix-ops.py b/examples/matrix-ops.py
index ee04faafa..d5f5bea2a 100644
--- a/examples/matrix-ops.py
+++ b/examples/matrix-ops.py
@@ -73,24 +73,24 @@ def fancy_matrix_mul(ctx_factory=cl.create_some_context):
     queue = cl.CommandQueue(ctx,
             properties=cl.command_queue_properties.PROFILING_ENABLE)
 
-    n = 16*10
+    n = 16*30
     from pymbolic import var
     a, b, c, i, j, k, n_sym = [var(s) for s in "abcijkn"]
 
     knl = lp.LoopKernel(ctx.devices[0],
-        "[n] -> {[i,j,k]: 0<=i,j,k<n}",
-        [
+            "[n] -> {[i,j,k]: 0<=i,j,k<n and n>0 }",
+            [
                 (c[i, j], a[i, k]*b[k, j])
                 ],
-        [
-            lp.ArrayArg("a", dtype, shape=(n_sym, n_sym)),
-            lp.ArrayArg("b", dtype, shape=(n_sym, n_sym)),
-            lp.ArrayArg("c", dtype, shape=(n_sym, n_sym)),
-            lp.ScalarArg("n", np.int32, approximately=1000),
-        ], name="fancy_matmul")
+            [
+                lp.ArrayArg("a", dtype, shape=(n_sym, n_sym)),
+                lp.ArrayArg("b", dtype, shape=(n_sym, n_sym)),
+                lp.ArrayArg("c", dtype, shape=(n_sym, n_sym)),
+                lp.ScalarArg("n", np.int32, approximately=1000),
+                ], name="fancy_matmul")
 
     knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
-    knl = lp.split_dimension(knl, "j", 17, outer_tag="g.1", inner_tag="l.0")
+    knl = lp.split_dimension(knl, "j", 16, outer_tag="g.1", inner_tag="l.0")
     knl = lp.split_dimension(knl, "k", 16)
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["k_inner", "j_inner"])
@@ -111,8 +111,8 @@ def fancy_matrix_mul(ctx_factory=cl.create_some_context):
         if check:
             sol = c.get()
             import matplotlib.pyplot as pt
-            pt.imshow(refsol-sol)
-            pt.show()
+            #pt.imshow(refsol-sol)
+            #pt.show()
             rel_err = la.norm(refsol-sol, "fro")/la.norm(refsol, "fro")
             assert rel_err < 1e-5, rel_err
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 9def9d5b8..8581b8423 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -23,7 +23,6 @@ register_mpz_with_pymbolic()
 
 
 # TODO: Divisibility
-# TODO: Multi-D array access
 # TODO: Non-multiple loop splits
 #       FIXME: Splitting an uneven-split loop?
 # TODO: nD Texture access
@@ -32,6 +31,7 @@ register_mpz_with_pymbolic()
 # TODO: Try different kernels
 # TODO:   - Tricky: Convolution, FD
 # TODO: Try, fix indirect addressing
+# TODO: Recast slab checking in terms of # of conditionals
 
 # TODO: Custom reductions per red. axis
 # TODO: Vectorize
@@ -117,6 +117,7 @@ def get_bounds_constraints(set, iname):
     bset_iname_dim_type, bset_iname_idx = bset.get_dim().get_var_dict()[iname]
 
     def examine_constraint(cns):
+        assert not cns.is_equality()
         coeffs = cns.get_coefficients_by_name()
 
         iname_coeff = int(coeffs.get(iname, 0))
@@ -170,7 +171,8 @@ def get_bounds(set, iname):
             from pytools import div_ceil
             ub = cfm(flatten(div_ceil(rhs+1, -iname_coeff)))
         else: #  iname_coeff > 0
-            lb = cfm(flatten(rhs//iname_coeff))
+            from pymbolic import expand
+            lb = cfm(flatten(expand(-rhs)//iname_coeff))
 
     return lb, ub
 
@@ -181,6 +183,29 @@ def cast_constraint_to_space(cns, new_space):
         factory = isl.Constraint.ineq_from_names
     return factory(new_space, cns.get_coefficients_by_name())
 
+def block_shift_constraint(cns, iname, multiple):
+    cns = copy_constraint(cns)
+    cns.set_constant(cns.get_constant()
+            + cns.get_coefficients_by_name()[iname]*multiple)
+    return cns
+
+def negate_constraint(cns):
+    assert not cns.is_equality()
+    # FIXME hackety hack
+    my_set = (isl.BasicSet.universe(cns.get_dim())
+            .add_constraint(cns))
+    my_set = my_set.complement()
+
+    results = []
+    def examine_basic_set(s):
+        s.foreach_constraint(results.append)
+    my_set.foreach_basic_set(examine_basic_set)
+    result, = results
+    return result
+
+def copy_constraint(cns):
+    return cast_constraint_to_space(cns, cns.get_dim())
+
 def get_dim_bounds(set):
     vars = set.get_dim().get_var_dict(dim_type.set).keys()
     return [get_bounds(set, v) for v in vars]
@@ -942,6 +967,52 @@ def insert_register_prefetches(kernel):
 
 # {{{ code generation
 
+class GeneratedCode(Record):
+    __slots__ = ["ast", "num_conditionals"]
+
+def gen_code_block(elements):
+    from cgen import Generable, Block
+
+    num_conditionals = 0
+    block_els = []
+    for el in elements:
+        if isinstance(el, GeneratedCode):
+            num_conditionals = num_conditionals + el.num_conditionals
+            block_els.append(el.ast)
+        elif isinstance(el, Generable):
+            block_els.append(el)
+        else:
+            raise ValueError("unidentifiable object in block")
+
+    if len(block_els) == 1:
+        ast, = block_els
+    else:
+        ast = Block(block_els)
+    return GeneratedCode(ast=ast, num_conditionals=num_conditionals)
+
+def wrap_with(cls, *args):
+    inner = args[-1]
+    args = args[:-1]
+
+    from cgen import If, Generable
+
+    if isinstance(inner, GeneratedCode):
+        num_conditionals = inner.num_conditionals
+        ast = inner.ast
+    elif isinstance(inner, Generable):
+        num_conditionals = 0
+        ast = inner
+
+    args = args + (ast,)
+    ast = cls(*args)
+
+    if isinstance(ast, If):
+        import re
+        cond_joiner_re = re.compile(r"\|\||\&\&")
+        num_conditionals += len(cond_joiner_re.split(ast.condition))
+
+    return GeneratedCode(ast=ast, num_conditionals=num_conditionals)
+
 # {{{ C code mapper
 
 class LoopyCCodeMapper(CCodeMapper):
@@ -1057,7 +1128,7 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
         from pytools import product
         total_realiz_size = product(realiz_lengths)
 
-        result = None
+        result = []
 
         cur_index = 0
 
@@ -1084,20 +1155,15 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
                     new_impl_domain)
 
             if cur_index+total_realiz_size > dim_length:
-                inner = If(
+                inner = wrap_with(If,
                         "%s < %s" % (ccm(pf_dim_expr), stop_index),
                         inner)
 
-            if result is None:
-                result = inner
-            elif isinstance(result, Block):
-                result.append(inner)
-            else:
-                result = Block([result, inner])
+            result.append(inner)
 
             cur_index += total_realiz_size
 
-        return result
+        return gen_code_block(result)
 
         # }}}
     else:
@@ -1107,7 +1173,7 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
         pf_dim_expr = var(pf_dim_var)
 
         lb_cns, ub_cns = flnd.kernel.get_bounds_constraints(pf_iname)
-        loop_slab = (isl.Set.universe(kernel.space)
+        loop_slab = (isl.Set.universe(flnd.kernel.space)
                 .add_constraint(lb_cns)
                 .add_constraint(ub_cns))
         new_impl_domain = implemented_domain.intersect(loop_slab)
@@ -1118,7 +1184,7 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
                 pf_dim_exprs+[pf_dim_expr], pf_idx_subst_map,
                 new_impl_domain)
 
-        return For(
+        return wrap_with(For,
                 "int %s = 0" % pf_dim_var,
                 "%s < %s" % (pf_dim_var, ccm(dim_length)),
                 "++%s" % pf_dim_var,
@@ -1127,8 +1193,10 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
         # }}}
 
 
-def generate_prefetch_code(ccm, kernel, sched_index, implemented_domain):
-    from cgen import (Block, Statement as S, Line, Comment)
+def generate_prefetch_code(cgs, kernel, sched_index, implemented_domain):
+    from cgen import Statement as S, Line, Comment
+
+    ccm = cgs.c_code_mapper
 
     # find surrounding schedule items
     if sched_index-1 >= 0:
@@ -1242,7 +1310,7 @@ def generate_prefetch_code(ccm, kernel, sched_index, implemented_domain):
 
     # }}}
 
-    new_block = Block([
+    new_block = [
             Comment(("prefetch %s dim: " % pf.input_vector) 
                 + ", ".join("%s -> %s"
                     % (pf_iname, 
@@ -1252,7 +1320,7 @@ def generate_prefetch_code(ccm, kernel, sched_index, implemented_domain):
                     for pf_iname, realiz_inames in zip(pf.inames, realization_inames)
                     )),
             Line(),
-            ])
+            ]
 
     # omit head sync primitive if we just came out of a prefetch
     if not isinstance(next_outer_sched_item, PrefetchDescriptor):
@@ -1274,62 +1342,136 @@ def generate_prefetch_code(ccm, kernel, sched_index, implemented_domain):
             "no sync needed"))
 
     new_block.extend([Line(),
-        build_loop_nest(ccm, kernel, sched_index+1, implemented_domain)])
+        build_loop_nest(cgs, kernel, sched_index+1, implemented_domain)])
 
-    return new_block
+    return gen_code_block(new_block)
 
 # }}}
 
 # {{{ per-axis loop nest code generation
 
-def generate_loop_dim_code(ccm, kernel, sched_index, 
+def generate_loop_dim_code(cgs, kernel, sched_index, 
         implemented_domain):
     from cgen import (POD, Block, Initializer,
-            For, If, Line, Comment, add_comment)
+            For, Line, Comment, add_comment,
+            make_multiple_ifs)
+
+    ccm = cgs.c_code_mapper
 
     space = implemented_domain.get_dim()
 
     iname = kernel.schedule[sched_index].iname
-    lb_cns, ub_cns = kernel.get_bounds_constraints(iname)
-    lb_cns = cast_constraint_to_space(lb_cns, space)
-    ub_cns = cast_constraint_to_space(ub_cns, space)
-
-    if 0:
-        # FIXME jostle the constant to see if we can get a full slab
-        # test via slab.is_subset(...)
-
-        unconstrained_slab_found = False
-        for lower_incr, upper_incr in [
-                (0,0), 
-                #(0,-1), (1,0), (1,-1)
-                ]:
-            slab_start = start+start_incr
-            slab_stop = stop+stop_incr
-            print slab_start, slab_stop
-            print "SLAB", slab
-            slab_intersection = current_domain.intersect(slab)
-            if has_non_slab_constraints(iname, set, slab_start, slab_stop):
-                pass
+    lb_cns_orig, ub_cns_orig = kernel.get_bounds_constraints(iname)
+    lb_cns_orig = cast_constraint_to_space(lb_cns_orig, space)
+    ub_cns_orig = cast_constraint_to_space(ub_cns_orig, space)
+
+    # jostle the constant in {lb,ub}_cns to see if we can get
+    # fewer conditionals in the bulk middle segment
+
+    class TrialRecord(Record):
+        pass
 
-    loop_slab = (isl.Set.universe(kernel.space)
-            .add_constraint(lb_cns)
-            .add_constraint(ub_cns))
+    if cgs.try_slab_partition:
+        trial_cgs = cgs.copy(try_slab_partition=False)
+        trials = []
+
+        if "outer" in iname:
+            variants = [ (0,0), (0,-1), ]
+        else:
+            variants = [(0,0)]
+
+        for lower_incr, upper_incr in variants:
+
+            lb_cns = block_shift_constraint(lb_cns_orig, iname, -lower_incr)
+            ub_cns = block_shift_constraint(ub_cns_orig, iname, -upper_incr)
+
+            bulk_slab = (isl.Set.universe(kernel.space)
+                    .add_constraint(lb_cns)
+                    .add_constraint(ub_cns))
+            bulk_impl_domain = implemented_domain.intersect(bulk_slab)
+            if not bulk_impl_domain.is_empty():
+                inner = build_loop_nest(trial_cgs, kernel, sched_index+1, 
+                        bulk_impl_domain)
+
+                trials.append((TrialRecord(
+                    lower_incr=lower_incr,
+                    upper_incr=upper_incr,
+                    bulk_slab=bulk_slab),
+                    (inner.num_conditionals,
+                        # when all num_conditionals are equal, choose min increments
+                        abs(upper_incr)+abs(lower_incr))))
+
+        from pytools import argmin2
+        chosen = argmin2(trials)
+    else:
+        bulk_slab = (isl.Set.universe(kernel.space)
+                .add_constraint(lb_cns_orig)
+                .add_constraint(ub_cns_orig))
+        chosen = TrialRecord(
+                    lower_incr=0,
+                    upper_incr=0,
+                    bulk_slab=bulk_slab)
+
+    slabs = []
+    if chosen.lower_incr:
+        slabs.append(isl.Set.universe(kernel.space)
+                .add_constraint(lb_cns_orig)
+                .add_constraint(
+                    negate_constraint(
+                        block_shift_constraint(
+                            lb_cns_orig, iname, -chosen.lower_incr))))
+
+    slabs.append(chosen.bulk_slab)
+
+    if chosen.upper_incr:
+        slabs.append(isl.Set.universe(kernel.space)
+                .add_constraint(ub_cns_orig)
+                .add_constraint(
+                    negate_constraint(
+                        block_shift_constraint(
+                            ub_cns_orig, iname, -chosen.upper_incr))))
+
+    result = []
+    nums_of_conditionals = []
+
+    for slab in slabs:
+        new_impl_domain = implemented_domain.intersect(slab)
+        inner = build_loop_nest(cgs, kernel, sched_index+1, 
+                new_impl_domain)
 
-    new_impl_domain = implemented_domain.intersect(loop_slab)
+        tag = kernel.iname_to_tag.get(iname)
+        start, stop = get_bounds(slab, iname)
+
+        # FIXME what about equality constraints
+        if tag is None:
+            # regular loop
+            result.append(wrap_with(For,
+                    "int %s = %s" % (iname, ccm(start)),
+                    "%s < %s" % (iname, ccm(stop)),
+                    "++%s" % iname, 
+                    inner))
+        else:
+            # parallel loop
+            nums_of_conditionals.append(inner.num_conditionals)
+            if chosen.lower_incr == 0 and chosen.upper_incr == 0:
+                assert len(slabs) == 1
+                return inner
+            else:
+                result.append(
+                        ("%s <= %s && %s < %s"
+                            % (ccm(start), iname, iname, ccm(stop)), 
+                            inner.ast))
 
-    tag = kernel.iname_to_tag.get(iname)
     if tag is None:
         # regular loop
-        start, stop = kernel.get_bounds(iname)
-        return For(
-                "int %s = %s" % (iname, ccm(start)),
-                "%s < %s" % (iname, ccm(stop)),
-                "++%s" % iname, 
-                build_loop_nest(ccm, kernel, sched_index+1,
-                    new_impl_domain))
+        return gen_code_block(result)
     else:
-        return build_loop_nest(ccm, kernel, sched_index+1, 
-                new_impl_domain)
+        # parallel loop
+        return GeneratedCode(
+                ast=make_multiple_ifs(result, base="last"),
+                num_conditionals=min(nums_of_conditionals))
+
+
 
 # }}}
 
@@ -1347,10 +1489,6 @@ def get_valid_index_vars(kernel, sched_index, exclude_tags=()):
 
 def wrap_in_bounds_checks(ccm, domain, valid_index_vars, implemented_domain, stmt):
     from cgen import If
-    have_too_much = not implemented_domain.subtract(domain).is_empty()
-
-    if not have_too_much:
-        return stmt
 
     domain_bsets = []
     domain.foreach_basic_set(domain_bsets.append)
@@ -1373,7 +1511,7 @@ def wrap_in_bounds_checks(ccm, domain, valid_index_vars, implemented_domain, stm
     projected_domain_bset.foreach_constraint(examine_constraint)
 
     if necessary_constraints:
-        stmt = If(
+        stmt = wrap_with(If,
                 "\n&& ".join(
                     "(%s >= 0)" % ccm(constraint_to_expr(cns))
                     for cns in necessary_constraints),
@@ -1385,9 +1523,10 @@ def wrap_in_bounds_checks(ccm, domain, valid_index_vars, implemented_domain, stm
 
 # {{{ codegen top-level dispatch
 
-def build_loop_nest(ccm, kernel, sched_index, implemented_domain):
-    from cgen import (POD, Block, Initializer, Assign, Statement as S,
-            block_if_necessary)
+def build_loop_nest(cgs, kernel, sched_index, implemented_domain):
+    ccm = cgs.c_code_mapper
+
+    from cgen import (POD, Initializer, Assign, Statement as S)
 
     if sched_index >= len(kernel.schedule):
         # {{{ write innermost loop body
@@ -1403,37 +1542,37 @@ def build_loop_nest(ccm, kernel, sched_index, implemented_domain):
 
         return wrap_in_bounds_checks(ccm, kernel.domain, 
                 get_valid_index_vars(kernel, sched_index), 
-                implemented_domain, block_if_necessary(insns))
+                implemented_domain, gen_code_block(insns))
 
         # }}}
 
     sched_item = kernel.schedule[sched_index]
 
     if isinstance(sched_item, ScheduledLoop):
-        return generate_loop_dim_code(ccm, kernel, sched_index, 
+        return generate_loop_dim_code(cgs, kernel, sched_index, 
                 implemented_domain)
 
     elif isinstance(sched_item, WriteOutput):
-        return Block(
+        return gen_code_block(
                 [Initializer(POD(kernel.arg_dict[lvalue.aggregate.name].dtype,
                     "tmp_"+lvalue.aggregate.name), 0)
                     for lvalue, expr in kernel.instructions]
-                +[build_loop_nest(ccm, kernel, sched_index+1, implemented_domain)]+
+                +[build_loop_nest(cgs, kernel, sched_index+1, implemented_domain)]+
                 [wrap_in_bounds_checks(ccm, kernel.domain, 
                     get_valid_index_vars(kernel, sched_index),
                     implemented_domain,
-                    block_if_necessary([
+                    gen_code_block([
                         Assign(
                             ccm(lvalue),
                             "tmp_"+lvalue.aggregate.name)
                         for lvalue, expr in kernel.instructions]))])
 
     elif isinstance(sched_item, PrefetchDescriptor):
-        return generate_prefetch_code(ccm, kernel, sched_index, implemented_domain)
+        return generate_prefetch_code(cgs, kernel, sched_index, implemented_domain)
 
     elif isinstance(sched_item, RegisterPrefetch):
         agg_name = sched_item.subscript_expr.aggregate.name
-        return Block([
+        return gen_code_block([
             wrap_in_bounds_checks(ccm, kernel, sched_index, implemented_domain,
                 Initializer(POD(kernel.arg_dict[agg_name].dtype,
                     sched_item.new_name),
@@ -1441,7 +1580,8 @@ def build_loop_nest(ccm, kernel, sched_index, implemented_domain):
                     % (agg_name,
                         ccm(sched_item.subscript_expr.index)))),
 
-            build_loop_nest(ccm, kernel, sched_index+1, implemented_domain)])
+            build_loop_nest(cgs, kernel, sched_index+1, implemented_domain)
+            ])
 
     else:
         raise ValueError("invalid schedule item encountered")
@@ -1450,6 +1590,9 @@ def build_loop_nest(ccm, kernel, sched_index, implemented_domain):
 
 # {{{ main code generation entrypoint
 
+class CodeGenerationState(Record):
+    __slots__ = ["c_code_mapper", "try_slab_partition"]
+
 def generate_code(kernel):
     from cgen import (FunctionBody, FunctionDeclaration, \
             POD, Value, RestrictPointer, ArrayOf, Module, Block,
@@ -1511,7 +1654,7 @@ def generate_code(kernel):
                 arg_decl = Const(arg_decl)
             arg_decl = CLGlobal(arg_decl)
         else:
-            arg_decl = POD(arg.dtype, arg.name)
+            arg_decl = Const(POD(arg.dtype, arg.name))
 
         if arg.dtype in [np.float64, np.complex128]:
             has_double = True
@@ -1558,9 +1701,10 @@ def generate_code(kernel):
 
     # }}}
 
-    body.extend([
-        Line(),
-        build_loop_nest(ccm, kernel, 0, isl.Set.universe(kernel.space))])
+    cgs = CodeGenerationState(c_code_mapper=ccm, try_slab_partition=True)
+    gen_code = build_loop_nest(cgs, kernel, 0, isl.Set.universe(kernel.space))
+    body.extend([Line(), gen_code.ast])
+    print "# conditionals: %d" % gen_code.num_conditionals
 
     mod.append(
         FunctionBody(
-- 
GitLab