From b7682eaf440832bdfb18f7e00b86daf99656f688 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 22 Jul 2011 19:00:07 -0500
Subject: [PATCH] Padding of work-item axes. A mock DG (discontinuous Galerkin) example. An 'edit' flag.
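
The padded split stores a "forced length" on the inner (work-item) tag:
the work group is launched at the padded size, and indices beyond the
true loop extent are masked off by a generated conditional. A rough
usage sketch (names follow the DG example added in examples/matrix-ops.py):

    # split "i" into 30-wide chunks, but launch 32 work items per group
    knl = lp.split_dimension(knl, "i", 30, 32, outer_tag="g.0", inner_tag="l.0")

    # edit=True opens the generated source in an editor before it is built
    lp.drive_timing_run(kernel_gen, queue, launcher, flop_count,
            force_rebuild=True, edit=True)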

---
 examples/matrix-ops.py | 108 ++++++++++++++--------
 loopy/__init__.py      | 198 +++++++++++++++++++++++++----------------
 2 files changed, 194 insertions(+), 112 deletions(-)

diff --git a/examples/matrix-ops.py b/examples/matrix-ops.py
index e1f81b878..a317fa532 100644
--- a/examples/matrix-ops.py
+++ b/examples/matrix-ops.py
@@ -10,10 +10,12 @@ import loopy as lp
 FAST_OPTIONS = ["-cl-mad-enable", "-cl-fast-relaxed-math", 
         "-cl-no-signed-zeros", "-cl-strict-aliasing"]
 
-def make_well_conditioned_dev_matrix(queue, n, dtype=np.float32, order="C", ran_factor=1, od=0):
+def make_well_conditioned_dev_matrix(queue, shape, dtype=np.float32, order="C", ran_factor=1, od=0):
+    if isinstance(shape, int):
+        shape = (shape, shape)
     ary = np.asarray(
-        ran_factor*np.random.randn(n, n)
-        + 5*np.eye(n, k=od),
+        ran_factor*np.random.randn(*shape)
+        + 5*np.eye(max(shape), k=od)[:shape[0], :shape[1]],
         dtype=dtype, order=order)
 
     return cl_array.to_device(queue, ary)
@@ -36,7 +38,7 @@ DEBUG_PREAMBLE = r"""
 def check_error(refsol, sol):
     rel_err = la.norm(refsol-sol, "fro")/la.norm(refsol, "fro")
     if DO_CHECK and rel_err > 1e-5:
-        if 0:
+        if 1:
             import matplotlib.pyplot as pt
             pt.imshow(refsol-sol)
             pt.colorbar()
@@ -108,61 +110,93 @@ def plain_matrix_mul(ctx_factory=cl.create_some_context):
 
 
 
-def image_matrix_mul(ctx_factory=cl.create_some_context):
+def dg_matrix_mul(ctx_factory=cl.create_some_context):
     dtype = np.float32
     ctx = ctx_factory()
     order = "C"
     queue = cl.CommandQueue(ctx,
             properties=cl.command_queue_properties.PROFILING_ENABLE)
 
-    n = 16*100
+    Np = 84
+    Np_padded = 96
+    K = 20000
+    dim = 3
+    num_flds = 6
+
     from pymbolic import var
-    a, b, c, i, j, k, n_sym = [var(s) for s in "abcijkn"]
+    fld = var("fld")
+    matrix_names = ["d%d" % i for i in range(dim)]
+    i, j, k = [var(s) for s in "i j k".split()]
+
+    fld_strides = (1, Np_padded)
 
     knl = lp.LoopKernel(ctx.devices[0],
-            "{[i,j,k]: 0<=i,j,k<%d}" % n,
+            "{[i,j,k]: 0<=i,j< %d and 0<=k<%d}" % (Np, K),
             [
-                (c[i, j], a[i, k]*b[k, j])
+                (var(mn+"fld%d" % ifld)[i, k], 
+                    var(mn)[i, j]*var("fld%d" % ifld)[j, k])
+                for mn in matrix_names
+                for ifld in range(num_flds)
                 ],
-            [
-                lp.ImageArg("a", dtype, 2),
-                lp.ImageArg("b", dtype, 2),
-                #lp.ArrayArg("a", dtype, shape=(n, n), order=order),
-                #lp.ArrayArg("b", dtype, shape=(n, n), order=order),
-                lp.ArrayArg("c", dtype, shape=(n, n), order=order),
+            [lp.ImageArg(mn, dtype, 2) for mn in matrix_names]
+            + [lp.ArrayArg("fld%d" % ifld, dtype,
+                strides=fld_strides)
+                for ifld in range(num_flds)
+                ]
+            + [lp.ArrayArg(mn+"fld%d" % ifld, dtype,
+                strides=fld_strides)
+                for ifld in range(num_flds)
+                for mn in matrix_names
                 ],
-            name="matmul")
-
-    knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
-    knl = lp.split_dimension(knl, "j", 16, outer_tag="g.1", inner_tag="l.0")
-    knl = lp.split_dimension(knl, "k", 32)
-    # conflict-free
-    knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
-    knl = lp.add_prefetch(knl, 'b', ["j_inner", "k_inner"])
+            name="dg_matmul")
+
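+    # split "i" into 30-wide chunks, but pad each group to 32 work items;
+    # the two extra work items per group are masked off by a generated "if"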
+    knl = lp.split_dimension(knl, "i", 30, 32, outer_tag="g.0", inner_tag="l.0")
+    knl = lp.split_dimension(knl, "k", 16, outer_tag="g.1", inner_tag="l.1")
+    assert Np % 2 == 0
+    #knl = lp.split_dimension(knl, "j", Np//2)
+    #knl = lp.split_dimension(knl, "k", 32)
+
+    #for mn in matrix_names:
+        #knl = lp.add_prefetch(knl, mn, ["j", "i_inner"])
+    for ifld in range(num_flds):
+        knl = lp.add_prefetch(knl, 'fld%d' % ifld, ["k_inner", "j"])
     assert knl.get_invalid_reason() is None
 
-    kernel_gen = (lp.insert_register_prefetches(knl)
-            for knl in lp.generate_loop_schedules(knl))
+    kernel_gen = list(lp.insert_register_prefetches(knl)
+            for knl in lp.generate_loop_schedules(knl))[:1]
 
-    a = make_well_conditioned_dev_matrix(queue, n, dtype=dtype, order=order)
-    b = make_well_conditioned_dev_matrix(queue, n, dtype=dtype, order=order)
-    c = cl_array.empty_like(a)
-    refsol = np.dot(a.get(), b.get())
-    a_img = cl.image_from_array(ctx, a.get(), 1)
-    b_img = cl.image_from_array(ctx, b.get(), 1)
+    matrices = [
+            make_well_conditioned_dev_matrix(queue, Np, dtype=dtype, order="C")
+            for mn in matrix_names]
+    flds = [
+            make_well_conditioned_dev_matrix(queue, (Np_padded, K), dtype=dtype, order="F")
+            for ifld in range(num_flds)]
+    outputs = [cl_array.empty_like(flds[0])
+            for ifld in range(num_flds)
+            for mn in matrix_names]
+
+    ref_soln = [np.dot(mat.get(), fld.get()[:Np]) 
+            for fld in flds
+            for mat in matrices]
+
+    mat_images = [
+            cl.image_from_array(ctx, mat.get(), 1) for mat in matrices]
 
     def launcher(kernel, gsize, lsize, check):
-        evt = kernel(queue, gsize(), lsize(), a_img, b_img, c.data,
-                g_times_l=True)
+        args = mat_images + [fld.data for fld in flds] + [out.data for out in outputs]
+        evt = kernel(queue, gsize(), lsize(), *args, g_times_l=True)
 
         if check:
-            check_error(refsol, c.get())
+            for out, ref in zip(outputs, ref_soln):
+                check_error(ref, out.get()[:Np])
 
         return evt
 
-    lp.drive_timing_run(kernel_gen, queue, launcher, 2*n**3,
+    lp.drive_timing_run(kernel_gen, queue, launcher, num_flds*dim*2*(Np**2)*K,
             options=FAST_OPTIONS + ["-cl-nv-verbose"],
-            force_rebuild=True)
+            force_rebuild=True, edit=True)
 
 
 
@@ -197,6 +231,8 @@ def fancy_matrix_mul(ctx_factory=cl.create_some_context):
     knl = lp.split_dimension(knl, "k", 16)
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
     knl = lp.add_prefetch(knl, 'b', ["k_inner", "j_inner"])
     assert knl.get_invalid_reason() is None
 
     kernel_gen = (lp.insert_register_prefetches(knl)
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 07c2450a6..d863c43ca 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -22,22 +22,28 @@ register_mpz_with_pymbolic()
 
 
 
+# TODO: ILP Unroll
 
-# TODO: Try, fix reg. prefetch
-# TODO: Divisibility
+# TODO: Try, fix reg. prefetch (DG example) / CSEs
+# TODO: Custom reductions per red. axis
 # TODO: Functions
 # TODO: Common subexpressions
+# TODO: Parse ops from string
+# TODO: Investigate why kernel generation takes so long
+
+# TODO: Condition hoisting
+# TODO: Don't emit spurious barriers (e.g. when no for loop is scheduled before them)
+# TODO: Make code more readable
+
+# TODO: Split into multiple files.
+# TODO: Divisibility
 # TODO: Try different kernels
 # TODO:   - Tricky: Convolution, FD
 # TODO: Try, fix indirect addressing
-# TODO: ILP Unroll
 # TODO: User controllable switch for slab opt
 # TODO: User control over schedule
-# TODO: Condition hoisting
 # TODO: Separate all-bulk from non-bulk kernels.
 
-# TODO: Custom reductions per red. axis
-
 # TODO Tim: implement efficient div_ceil?
 # TODO Tim: why are corner cases inefficient?
 
@@ -49,38 +55,43 @@ class LoopyAdvisory(UserWarning):
 
 # {{{ index tags
 
-class IndexTag(object):
-    def __init__(self, axis=None):
-        self.axis = axis
+class IndexTag(Record):
+    __slots__ = []
 
-    def __eq__(self, other):
-        return (self.__class__ == other.__class__
-                and self.axis == other.axis)
+    def __hash__(self):
+        raise RuntimeError("use .key to hash index tags")
 
-    def __ne__(self, other):
-        return not self.__eq__(other)
+    @property
+    def key(self):
+        return type(self)
 
-    def __hash__(self):
-        return hash(type(self)) ^ hash(self.axis)
 
 
-class ImplicitlyParallelTag(IndexTag):
-    pass
+class AxisParallelTag(IndexTag):
+    __slots__ = ["axis", "forced_length"]
 
-class TAG_GROUP_IDX(ImplicitlyParallelTag):
-    def __repr__(self):
-        if self.axis is None:
-            return "GROUP_IDX"
-        else:
-            return "GROUP_IDX(%d)" % self.axis
+    def __init__(self, axis, forced_length=None):
+        Record.__init__(self,
+                axis=axis, forced_length=forced_length)
 
+    @property
+    def key(self):
+        return (type(self), self.axis)
 
-class TAG_WORK_ITEM_IDX(ImplicitlyParallelTag):
     def __repr__(self):
-        if self.axis is None:
-            return "WORK_ITEM_IDX"
+        if self.forced_length:
+            return "%s(%d, flen=%d)" % (
+                    self.print_name, self.axis,
+                    self.forced_length)
         else:
-            return "WORK_ITEM_IDX(%d)" % self.axis
+            return "%s(%d)" % (
+                    self.print_name, self.axis)
+
+class TAG_GROUP_IDX(AxisParallelTag):
+    print_name = "GROUP_IDX"
+
+class TAG_WORK_ITEM_IDX(AxisParallelTag):
+    print_name = "WORK_ITEM_IDX"
 
 class TAG_ILP_UNROLL(IndexTag):
     def __repr__(self):
@@ -235,6 +246,8 @@ def solve_constraint_for_bound(cns, iname):
         return "<", cfm(flatten(div_ceil(rhs+1, -iname_coeff)))
 
 
+
+
 def get_projected_bounds(set, iname):
     """Get an overapproximation of the loop bounds for the variable *iname*,
     as actual bounds.
@@ -503,17 +516,18 @@ class ScalarArg:
 # {{{ loop kernel object
 
 class LoopKernel(Record):
-    # possible attributes:
-    # - device, a PyOpenCL target device
-    # - domain
-    # - iname_to_tag
-    # - instructions
-    # - args
-    # - prefetch
-    # - schedule
-    # - register_prefetch
-    # - name
-    # - preamble
+    """
+    :ivar device: :class:`pyopencl.Device`
+    :ivar domain: :class:`islpy.BasicSet`
+    :ivar iname_to_tag: a map from inames to :class:`IndexTag` instances
+    :ivar instructions: a list of *(lvalue, expression)* tuples
+    :ivar args: a list of :class:`ScalarArg`, :class:`ArrayArg`, :class:`ImageArg`
+    :ivar prefetch:
+    :ivar schedule:
+    :ivar register_prefetch:
+    :ivar name: the name of the generated kernel function
+    :ivar preamble: code inserted verbatim at the top of the generated source
+    """
 
     def __init__(self, device, domain, instructions, args=None, prefetch={}, schedule=None,
             register_prefetch=None, name="loopy_kernel",
@@ -559,9 +573,10 @@ class LoopKernel(Record):
 
     @property
     @memoize_method
-    def tag_to_iname(self):
-        from pytools import reverse_dictionary
-        return reverse_dictionary(self.iname_to_tag)
+    def tag_key_to_iname(self):
+        return dict(
+                (tag.key, iname)
+                for iname, tag in self.iname_to_tag.iteritems())
 
     @property
     @memoize_method
@@ -604,7 +619,7 @@ class LoopKernel(Record):
         from itertools import count
         for i in count():
             try:
-                dim = self.tag_to_iname[tag_type(i)]
+                dim = self.tag_key_to_iname[tag_type(i).key]
             except KeyError:
                 return result
             else:
@@ -622,12 +637,17 @@ class LoopKernel(Record):
 
         return get_projected_bounds(self.domain, iname)
 
-    def tag_type_bounds(self, tag_cls):
-        return [self.get_projected_bounds(iname)
-                for iname in self.ordered_inames_by_tag_type(tag_cls)]
-
     def tag_type_lengths(self, tag_cls):
-        return [stop-start for start, stop in self.tag_type_bounds(tag_cls)]
+        def get_length(iname):
+            tag = self.iname_to_tag[iname]
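+            # a padded (forced-length) axis reports its padded size, so the
+            # launched grid covers the padding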
+            if tag.forced_length is not None:
+                return tag.forced_length
+
+            start, stop = self.get_projected_bounds(iname)
+            return stop-start
+
+        return [get_length(iname)
+                for iname in self.ordered_inames_by_tag_type(tag_cls)]
 
     def tag_or_iname_to_iname(self, s):
         try:
@@ -635,7 +655,7 @@ class LoopKernel(Record):
         except ValueError:
             pass
         else:
-            return self.tag_to_iname[tag]
+            return self.tag_key_to_iname[tag.key]
 
         if s not in self.all_inames():
             raise RuntimeError("invalid index name '%s'" % s)
@@ -715,20 +735,33 @@ class LoopKernel(Record):
 
         return copy
 
-    def split_dimension(self, name, inner_length, outer_name=None, inner_name=None,
+    def split_dimension(self, name, inner_length, padded_length=None,
+            outer_name=None, inner_name=None,
             outer_tag=None, inner_tag=None):
 
         outer_tag = parse_tag(outer_tag)
         inner_tag = parse_tag(inner_tag)
 
-        new_tags = set(tag for tag in [outer_tag, inner_tag] if tag is not None)
-
         if self.iname_to_tag.get(name) is not None:
             raise RuntimeError("cannot split tagged dimension '%s'" % name)
 
-        repeated_tags = new_tags & set(self.iname_to_tag.values())
-        if repeated_tags:
-            raise RuntimeError("repeated tag(s): %s" % repeated_tags)
+        # {{{ check for repeated tag keys
+
+        new_tag_keys = set(tag.key
+                for tag in [outer_tag, inner_tag]
+                if tag is not None)
+
+        repeated_tag_keys = new_tag_keys & set(
+                tag.key for tag in
+                self.iname_to_tag.itervalues())
+
+        if repeated_tag_keys:
+            raise RuntimeError("repeated tag(s): %s" % repeated_tag_keys)
+
+        # }}}
+
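+        # a padded inner axis is realized by forcing the tag's length; code
+        # generation later guards the extra indices with a conditional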
+        if padded_length is not None:
+            inner_tag = inner_tag.copy(forced_length=padded_length)
 
         if outer_name is None:
             outer_name = name+"_outer"
@@ -945,9 +978,8 @@ def generate_loop_schedules(kernel):
             # in this branch. at least one of its loop dimensions
             # was already scheduled, and that dimension is not
             # borrowable.
-            print "UNSCHEDULABLE:"
-            print_kernel_info(kernel)
-            raw_input()
+
+            #print "UNSCHEDULABLE", kernel.schedule
             return
 
         new_kernel = kernel.copy(schedule=prev_schedule+[pf])
@@ -1627,15 +1659,24 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
                         block_shift_constraint(
                             ub_cns_orig, iname, -chosen.upper_incr)))))
 
-    if isinstance(tag, ImplicitlyParallelTag):
+    if isinstance(tag, AxisParallelTag):
         # For a parallel loop dimension, the global loop bounds are
         # automatically obeyed--simply because no work items are launched
         # outside the requested grid.
-
-        implemented_domain = implemented_domain.intersect(
-                isl.Set.universe(kernel.space)
-                .add_constraint(lb_cns_orig)
-                .add_constraint(ub_cns_orig))
+        #
+        # For a forced length, this is actually implemented
+        # by an if below.
+
+        if tag.forced_length is None:
+            implemented_domain = implemented_domain.intersect(
+                    isl.Set.universe(kernel.space)
+                    .add_constraint(lb_cns_orig)
+                    .add_constraint(ub_cns_orig))
+        else:
+            impl_len = tag.forced_length
+            start, _ = kernel.get_projected_bounds(iname)
+            implemented_domain = implemented_domain.intersect(
+                    make_slab(kernel.space, iname, start, start+impl_len))
 
     result = []
     nums_of_conditionals = []
@@ -1668,11 +1709,17 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
     if tag is None:
         # regular or unrolled loop
         return gen_code_block(result)
-    elif isinstance(tag, ImplicitlyParallelTag):
+
+    elif isinstance(tag, AxisParallelTag):
         # parallel loop
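+        # with a forced (padded) length, even the last slab must be guarded
+        # by an "if", since work items beyond the true extent are launched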
+        if tag.forced_length is None:
+            base = "last"
+        else:
+            base = None
         return GeneratedCode(
-                ast=make_multiple_ifs(result, base="last"),
+                ast=make_multiple_ifs(result, base=base),
                 num_conditionals=min(nums_of_conditionals))
+
     else:
         assert False, "we aren't supposed to get here"
 
@@ -1938,11 +1985,8 @@ def generate_code(kernel):
 
     body = Block()
 
-    group_size = kernel.tag_type_lengths(TAG_WORK_ITEM_IDX)
-
     # {{{ examine arg list
 
-
     def restrict_ptr_if_not_nvidia(arg):
         from cgen import Pointer, RestrictPointer
 
@@ -2073,12 +2117,12 @@ def get_input_access_descriptors(kernel):
     from pytools import flatten
     result = {}
     for ivec in kernel.input_vectors():
-        result[ivec] = [
+        result[ivec] = set(
                 (ivec, iexpr)
                 for iexpr in flatten(
                     VariableIndexExpressionCollector(ivec)(expression)
                     for lvalue, expression in kernel.instructions
-                    )]
+                    ))
 
     return result
 
@@ -2123,7 +2167,7 @@ def add_prefetch(kernel, input_access_descr, tags_or_inames, loc_fetch_axes={}):
 
 class CompiledKernel:
     def __init__(self, context, kernel, size_args=None, options=[],
-            force_rebuild=False):
+            force_rebuild=False, edit=False):
         self.kernel = kernel
         self.code = generate_code(kernel)
 
@@ -2131,8 +2175,9 @@ class CompiledKernel:
             from time import time
             self.code = "/* %s */\n%s" % (time(), self.code)
 
-        #from pytools import invoke_editor
-        #self.code = invoke_editor(self.code)
+        if edit:
+            from pytools import invoke_editor
+            self.code = invoke_editor(self.code)
 
         try:
             self.cl_kernel = getattr(
@@ -2178,7 +2223,8 @@ class CompiledKernel:
 
 # {{{ timing driver
 def drive_timing_run(kernel_generator, queue, launch, flop_count=None,
-        options=[], print_code=True, force_rebuild=False):
+        options=[], print_code=True, force_rebuild=False,
+        edit=False):
 
     def time_run(compiled_knl, warmup_rounds=2, timing_rounds=5):
         check = True
@@ -2203,7 +2249,7 @@ def drive_timing_run(kernel_generator, queue, launch, flop_count=None,
     for kernel in kernel_generator:
 
         compiled = CompiledKernel(queue.context, kernel, options=options,
-                force_rebuild=force_rebuild)
+                force_rebuild=force_rebuild, edit=edit)
 
         print "-----------------------------------------------"
         print "SOLUTION #%d" % soln_count
-- 
GitLab