From 186aee9c51b94f314b5f01811f4c59cfc1f76e22 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 Jul 2011 16:16:57 -0400
Subject: [PATCH] Further progress on the way to isl loopy.

---
 examples/matrix-ops.py |   4 +-
 loopy/__init__.py      | 352 +++++++++++++++++++----------------------
 2 files changed, 167 insertions(+), 189 deletions(-)

diff --git a/examples/matrix-ops.py b/examples/matrix-ops.py
index abfa7c0c1..9d9bbfb85 100644
--- a/examples/matrix-ops.py
+++ b/examples/matrix-ops.py
@@ -40,8 +40,8 @@ def plain_matrix_mul(ctx_factory=cl.create_some_context):
     knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
     knl = lp.split_dimension(knl, "j", 16, outer_tag="g.1", inner_tag="l.0")
     knl = lp.split_dimension(knl, "k", 16)
-    knl = lp.add_prefetch_dims(knl, 'a', ["i_inner", "k_inner"])
-    knl = lp.add_prefetch_dims(knl, 'b', ["k_inner", "j_inner"])
+    knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
+    knl = lp.add_prefetch(knl, 'b', ["k_inner", "j_inner"])
     assert knl.get_invalid_reason() is None
 
     kernel_gen = (lp.insert_register_prefetches(knl)
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 9b6361ecb..867672c7c 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -118,6 +118,61 @@ def parse_tag(tag):
 
 # }}}
 
+# {{{ pymbolic mappers
+
+class CoefficientCollector(RecursiveMapper):
+    def map_sum(self, expr):
+        stride_dicts = [self.rec(ch) for ch in expr.children]
+
+        result = {}
+        for stride_dict in stride_dicts:
+            for var, stride in stride_dict.iteritems():
+                if var in result:
+                    result[var] += stride
+                else:
+                    result[var] = stride
+
+        return result
+
+    def map_product(self, expr):
+        result = {}
+        for i, ch in enumerate(expr.children):
+            strides = self.rec(ch)
+            from pymbolic import flattened_product
+            prod_other_children = flattened_product(
+                    expr.children[:i] + expr.children[(i+1):])
+
+            for var, stride in strides.iteritems():
+                if var in result:
+                    raise NotImplementedError(
+                            "nonlinear index expression")
+                else:
+                    result[var] = prod_other_children*stride
+
+        return result
+
+    def map_divide(self, expr):
+        num_strides = self.rec(expr.numerator)
+        denom_strides = self.rec(expr.denominator)
+
+        if denom_strides:
+            raise NotImplementedError
+
+        return dict(
+                (var, stride/expr.denominator)
+                    for var, stride in num_strides.iteritems())
+
+    def map_constant(self, expr):
+        return {}
+
+    def map_variable(self, expr):
+        return {expr.name: 1}
+
+    def map_subscript(self, expr):
+        raise RuntimeError("cannot gather coefficients--indirect addressing in use")
+
+# }}}
+
 # {{{ loop dim, loop domain, kernel
 
 class LoopDimension(Record):
@@ -157,60 +212,6 @@ class LoopDimension(Record):
 
 
 
-class LoopDomain(Record):
-    __slots__ = ["dims"]
-
-    def name_to_idx(self, name):
-        for i, dim in enumerate(self.dims):
-            if dim.name == name:
-                return i
-        else:
-            raise KeyError("invalid dimension name: %s" % name)
-
-    def name_to_dim(self, name):
-        for dim in self.dims:
-            if dim.name == name:
-                return dim
-        else:
-            raise KeyError("invalid dimension name: %s" % name)
-
-    def tag_to_idx(self, tag):
-        for i, dim in enumerate(self.dims):
-            if dim.tag == tag:
-                return i
-        raise KeyError("invalid tag: %s" % tag)
-
-    def indices_by_tag_type(self, tag_type):
-        return [i for i, dim in enumerate(self.dims)
-                if isinstance(dim.tag, tag_type)]
-
-    def dims_by_tag_type(self, tag_type):
-        return [dim for dim in self.dims
-                if isinstance(dim.tag, tag_type)]
-
-    def ordered_inames_by_tag_type(self, tag_type):
-        result = []
-        from itertools import count
-        for i in count():
-            try:
-                dim = self.tag_to_iname[tag_type(i)]
-            except KeyError:
-                return result
-            else:
-                result.append(dim)
-
-    def dims_by_tag(self, tag):
-        return [dim for dim in self.dims if dim.tag == tag]
-
-    def set_dim(self, idx, new_dim):
-        return self.copy(dims=
-                self.dims[:idx]
-                + [new_dim]
-                + self.dims[(idx+1):])
-
-
-
-
 # {{{ arguments
 
 class ArrayArg:
@@ -265,7 +266,7 @@ class ScalarArg:
 
 
 
-class LoopKernel(LoopDomain):
+class LoopKernel(Record):
     # possible attributes:
     # - device, a PyOpenCL target device
     # - domain
@@ -304,7 +305,7 @@ class LoopKernel(LoopDomain):
                     parse_if_necessary(expr))
                 for lvalue, expr in instructions]
 
-        LoopDomain.__init__(self,
+        Record.__init__(self,
                 device=device, args=args, domain=domain, instructions=insns,
                 prefetch=prefetch, schedule=schedule,
                 register_prefetch=register_prefetch, name=name,
@@ -339,42 +340,46 @@ class LoopKernel(LoopDomain):
             return set(arg.name for arg in self.args if isinstance(arg, ScalarArg))
 
     @memoize_method
-    def all_indices(self):
+    def all_inames(self):
         return set(self.space.get_var_dict(dim_type.set).iterkeys())
 
     @memoize_method
-    def output_indices(self):
+    def output_inames(self):
         dm = DependencyMapper(include_subscripts=False)
 
         output_indices = set()
         for lvalue, expr in self.instructions:
             output_indices.update(
                     set(v.name for v in dm(lvalue))
-                    & self.all_indices())
+                    & self.all_inames())
 
         return output_indices - set(arg.name for arg in self.args)
 
     @memoize_method
-    def output_dimensions(self):
-        return [dim for dim in self.dims if dim.name in self.output_indices()]
+    def reduction_inames(self):
+        return self.all_inames() - self.output_inames()
+
+    def inames_by_tag_type(self, tag_type):
+        return [iname for iname in self.all_inames()
+                if isinstance(self.iname_to_tag.get(iname), tag_type)]
+
+    def ordered_inames_by_tag_type(self, tag_type):
+        result = []
+        from itertools import count
+        for i in count():
+            try:
+                dim = self.tag_to_iname[tag_type(i)]
+            except KeyError:
+                return result
+            else:
+                result.append(dim)
 
-    @memoize_method
-    def reduction_dimensions(self):
-        return [dim for dim in self.dims if dim.name not in self.output_indices()]
 
     def get_bounds(self, iname):
         """Get an overapproximation of the loop bounds for the variable *iname*."""
 
-        iname_dim_type, iname_idx = self.space.get_var_dict()[iname]
-        assert iname_dim_type == dim_type.set
-
         # project out every variable except iname
-        projected_domain = (self.domain
-                # vars after iname
-                .project_out(
-                    iname_dim_type, iname_idx+1, self.space.size(iname_dim_type)-iname_idx-1)
-                .project_out(
-                    iname_dim_type, 0, iname_idx))
+        projected_domain = isl.project_out_except(self.domain, [iname], [dim_type.set])
 
         basic_sets = []
         projected_domain.foreach_basic_set(basic_sets.append)
@@ -387,19 +392,21 @@ class LoopKernel(LoopDomain):
         upper_bounds = []
         lower_bounds = []
         bset = bset.remove_divs()
-        print bset
 
         bset_iname_dim_type, bset_iname_idx = bset.get_dim().get_var_dict()[iname]
 
+        from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
+        from pymbolic import flatten
+        cfm = CommutativeConstantFoldingMapper()
+
         def examine_constraint(cns):
             coeffs = cns.get_coefficients_by_name()
-            print coeffs
 
             iname_coeff = int(coeffs.get(iname, 0))
             if iname_coeff == 0:
                 return
 
-            rhs = cns.get_constant()
+            rhs = int(cns.get_constant())
             from pymbolic import var
             for var_name, coeff in coeffs.iteritems():
                 if var_name == iname:
@@ -408,19 +415,37 @@ class LoopKernel(LoopDomain):
 
             if iname_coeff < 0:
                 from pytools import div_ceil
-                upper_bounds.append(div_ceil(rhs, -iname_coeff))
+                upper_bounds.append(cfm(flatten(div_ceil(rhs+1, -iname_coeff))))
             else: #  iname_coeff > 0
-                lower_bounds.append(rhs//iname_coeff)
+                lower_bounds.append(cfm(flatten(rhs//iname_coeff)))
 
         bset.foreach_constraint(examine_constraint)
 
         lb, = lower_bounds
         ub, = upper_bounds
-        print iname, lb, ub
-        print "WRONG: Inclusive bounds!"
 
         return lb, ub
 
+    def address_map(self, index_expr):
+        if not isinstance(index_expr, tuple):
+            index_expr = (self.index_expr,)
+
+        coeff_coll = CoefficientCollector()
+        all_coeffs = tuple(coeff_coll(iexpr_i) for iexpr_i in index_expr)
+
+        amap = isl.Map.from_domain(self.domain).add_dims(dim_type.out, len(index_expr))
+        out_names = ["_ary_idx_%d" % i for i in range(len(index_expr))]
+
+        for i, out_name in enumerate(out_names):
+            amap = amap.set_dim_name(dim_type.out, i, out_name)
+
+        for i, (out_name, coeffs) in enumerate(zip(out_names, all_coeffs)):
+            coeffs[out_name] = -1
+            amap = amap.add_constraint(isl.Constraint.eq_from_names(
+                amap.get_dim(), 0, coeffs))
+
+        return amap
+
     def tag_type_bounds(self, tag_cls):
         return [self.get_bounds(iname)
                 for iname in self.ordered_inames_by_tag_type(tag_cls)]
@@ -440,12 +465,15 @@ class LoopKernel(LoopDomain):
         else:
             return self.tag_to_iname[tag]
 
-        if s not in self.all_indices():
+        if s not in self.all_inames():
             raise RuntimeError("invalid index name '%s'" % s)
 
         return s
 
     def local_mem_use(self):
+        from warnings import warn
+        warn("local_mem_use unimpl")
+        return 0
         return sum(pf.size() for pf in self.prefetch.itervalues())
 
     @memoize_method
@@ -457,7 +485,7 @@ class LoopKernel(LoopDomain):
             input_vectors.update(
                     set(v.name for v in dm(expr)))
 
-        return input_vectors - self.all_indices() - self.scalar_args()
+        return input_vectors - self.all_inames() - self.scalar_args()
 
     @memoize_method
     def output_vectors(self):
@@ -468,7 +496,7 @@ class LoopKernel(LoopDomain):
             output_vectors.update(
                     set(v.name for v in dm(lvalue)))
 
-        return output_vectors - self.all_indices() - self.scalar_args()
+        return output_vectors - self.all_inames() - self.scalar_args()
 
     def _subst_insns(self, old_var, new_expr):
         from pymbolic.mapper.substitutor import substitute
@@ -577,18 +605,18 @@ class LoopKernel(LoopDomain):
                 .copy(domain=new_domain, iname_to_tag=new_iname_to_tag))
 
     def get_invalid_reason(self):
-        gdims = self.tag_type_count(TAG_GROUP_IDX)
-        ldims = self.tag_type_count(TAG_WORK_ITEM_IDX)
-        1/0
-        if (max(len(gdims), len(ldims))
+        glens = self.tag_type_lengths(TAG_GROUP_IDX)
+        llens = self.tag_type_lengths(TAG_WORK_ITEM_IDX)
+        if (max(len(glens), len(llens))
                 > self.device.max_work_item_dimensions):
             return "too many work item dimensions"
 
-        for i in range(len(ldims)):
-            if ldims[i] > self.device.max_work_item_sizes[i]:
+        for i in range(len(llens)):
+            if llens[i] > self.device.max_work_item_sizes[i]:
                 return "group axis %d too big"
 
-        if self.group_size() > self.device.max_work_group_size:
+        from pytools import product
+        if product(llens) > self.device.max_work_group_size:
             return "work group too big"
 
         from pyopencl.characterize import usable_local_mem_size
@@ -601,6 +629,7 @@ class LoopKernel(LoopDomain):
 
 # {{{ local-mem prefetch-related
 
+
 class PrefetchDescriptor(Record):
     """
     Attributes:
@@ -608,9 +637,9 @@ class PrefetchDescriptor(Record):
     :ivar input_vector: A string indicating the input vector variable name.
     :ivar index_expr: An expression identifying the access which this prefetch
       serves.
-    :ivar dims: A sequence of loop dimensions identifying which part of the
-      input vector, given the index_expr, should be prefetched.
-    :ivar loc_fetch_axes: dictionary from integers 0..len(dims) to lists of
+    :ivar inames: A sequence of inames (i.e. loop dimensions) identifying which
+        part of the input vector, given the index_expr, should be prefetched.
+    :ivar loc_fetch_axes: dictionary from integers 0..len(inames) to lists of
       local index axes which should be used to realize that dimension of the
       prefetch. The last dimension in this list is used as the fastest-changing
       one.
@@ -623,15 +652,21 @@ class PrefetchDescriptor(Record):
     """
 
     def size(self):
-        from pytools import product
-        return (self.kernel.arg_dict[self.input_vector].dtype.itemsize
-                * product(dim.length for dim in self.dims))
+        my_image = (
+                isl.project_out_except(self.kernel.domain, self.inames, [dim_type.set])
+                .remove_divs())
+        assert my_image.is_box()
+
+        print my_image
+        print my_image.is_box()
+        1/0
+
 
     @memoize_method
     def free_variables(self):
         return set(var.name
                 for var in DependencyMapper()(self.index_expr)
-                ) - set(dim.name for dim in self.dims) - self.kernel.scalar_args()
+                ) - set(self.inames) - self.kernel.scalar_args()
 
     def hash(self):
         return (hash(type(self)) ^ hash(self.input_vector)
@@ -674,87 +709,31 @@ class VariableIndexExpressionCollector(CombineMapper):
 
 
 
-class StrideCollector(RecursiveMapper):
-    def __init__(self, arg):
-        self.arg = arg
-
-    def map_sum(self, expr):
-        stride_dicts = [self.rec(ch) for ch in expr.children]
-
-        result = {}
-        for stride_dict in stride_dicts:
-            for var, stride in stride_dict.iteritems():
-                if var in result:
-                    result[var] += stride
-                else:
-                    result[var] = stride
-
-        return result
-
-    def map_product(self, expr):
-        result = {}
-        for i, ch in enumerate(expr.children):
-            strides = self.rec(ch)
-            from pymbolic import flattened_product
-            prod_other_children = flattened_product(
-                    expr.children[:i] + expr.children[(i+1):])
-
-            for var, stride in strides.iteritems():
-                if var in result:
-                    raise NotImplementedError(
-                            "nonlinear index expression")
-                else:
-                    result[var] = prod_other_children*stride
-
-        return result
-
-    def map_divide(self, expr):
-        num_strides = self.rec(expr.numerator)
-        denom_strides = self.rec(expr.denominator)
-
-        if denom_strides:
-            raise NotImplementedError
-
-        return dict(
-                (var, stride/expr.denominator)
-                    for var, stride in num_strides.iteritems())
-
-    def map_constant(self, expr):
-        return {}
-
-    def map_variable(self, expr):
-        return {expr.name: 1}
-
-    def map_tuple(self, expr):
-        return self.rec(sum(
-            stride*expr_i for stride, expr_i in zip(
-                self.arg.strides, expr)))
-
-    def map_subscript(self, expr):
-        raise RuntimeError("cannot gather strides--indirect addressing in use")
-
 # }}}
 
 # {{{ loop scheduling
 
+class ScheduledLoop(Record):
+    __slots__ = ["iname"]
+
 def generate_loop_schedules(kernel):
     prev_schedule = kernel.schedule
     if prev_schedule is None:
-        prev_schedule = (
-            kernel.dims_by_tag_type(TAG_GROUP_IDX)
-            + kernel.dims_by_tag_type(TAG_WORK_ITEM_IDX))
+        prev_schedule = [
+                ScheduledLoop(iname=iname)
+                for iname in (
+                    kernel.ordered_inames_by_tag_type(TAG_GROUP_IDX)
+                    + kernel.ordered_inames_by_tag_type(TAG_WORK_ITEM_IDX))]
 
-    already_scheduled = set(sch_item
+    scheduled_inames = set(sch_item.iname
             for sch_item in prev_schedule
-            if isinstance(sch_item, LoopDimension))
+            if isinstance(sch_item, ScheduledLoop))
 
     # have a schedulable prefetch? load, schedule it
-    scheduled_names = set(dim.name for dim in already_scheduled)
-
     had_usable_prefetch = False
-    scheduled_work_item_dim_names = set(
-            dim.name for dim in already_scheduled
-            if isinstance(dim.tag, TAG_WORK_ITEM_IDX))
+    scheduled_work_item_inames = set(
+            iname for iname in scheduled_inames
+            if isinstance(kernel.iname_to_tag.get(iname), TAG_WORK_ITEM_IDX))
 
     for pf in kernel.prefetch.itervalues():
         # already scheduled? never mind then.
@@ -762,14 +741,13 @@ def generate_loop_schedules(kernel):
             continue
 
         # a free variable not known yet? then we're not ready
-        if not pf.free_variables() <= scheduled_names:
+        if not pf.free_variables() <= scheduled_inames:
             continue
 
         # a prefetch variable already scheduled, but not borrowable?
         # (only work item index variables are borrowable)
-        pf_loop_names = set(dim.name for dim in pf.dims)
 
-        if pf_loop_names & (already_scheduled - scheduled_work_item_dim_names):
+        if set(pf.inames) & (scheduled_inames - scheduled_work_item_inames):
             # dead end: we won't be able to schedule this prefetch
             # in this branch. at least one of its loop dimensions
             # was already scheduled, and that dimension is not
@@ -788,19 +766,17 @@ def generate_loop_schedules(kernel):
         return
 
     # Build set of potentially schedulable variables
-    schedulable = set(kernel.dims)
-
     # Don't re-schedule already scheduled variables
-    schedulable -= already_scheduled
+    schedulable = kernel.all_inames() - scheduled_inames
 
     # Don't schedule reduction variables until all output
     # variables are taken care of. Once they are, schedule
     # output writing.
-    serial_output_dims = set(od for od in kernel.output_dimensions()
-            if od.tag is None)
+    serial_output_inames = set(oin for oin in kernel.output_inames()
+            if kernel.iname_to_tag.get(oin) is None)
 
-    if not serial_output_dims <= already_scheduled:
-        schedulable -= set(kernel.reduction_dimensions())
+    if not serial_output_inames <= scheduled_inames:
+        schedulable -= kernel.reduction_dimensions()
     else:
         if not any(isinstance(sch_item, WriteOutput)
                 for sch_item in prev_schedule):
@@ -810,29 +786,29 @@ def generate_loop_schedules(kernel):
 
     # Don't schedule variables that are prefetch axes
     # for not-yet-scheduled prefetches.
-    unsched_prefetch_axes = set(dim
+    unsched_prefetch_axes = set(iname
             for pf in kernel.prefetch.itervalues()
             if pf not in prev_schedule
-            for dim in pf.dims)
+            for iname in pf.inames)
     schedulable -= unsched_prefetch_axes
 
     if schedulable:
         # have a schedulable variable? schedule a loop for it, recurse
-        for dim in schedulable:
-            new_kernel = kernel.copy(schedule=prev_schedule+[dim])
+        for iname in schedulable:
+            new_kernel = kernel.copy(schedule=prev_schedule+[ScheduledLoop(iname=iname)])
             for knl in generate_loop_schedules(new_kernel):
                 yield knl
     else:
         # all loop dimensions and prefetches scheduled?
         # great! yield the finished product if it is complete
 
-        all_dims_scheduled = len(already_scheduled) == len(kernel.dims)
+        all_inames_scheduled = len(scheduled_inames) == len(kernel.all_inames())
         all_pf_scheduled =  len(set(sch_item for sch_item in prev_schedule
             if isinstance(sch_item, PrefetchDescriptor))) == len(kernel.prefetch)
         output_scheduled = len(set(sch_item for sch_item in prev_schedule
             if isinstance(sch_item, WriteOutput))) == 1
 
-        if all_dims_scheduled and all_pf_scheduled and output_scheduled:
+        if all_inames_scheduled and all_pf_scheduled and output_scheduled:
             yield kernel
 
 # }}}
@@ -862,7 +838,7 @@ class AllIndexExpressionCollector(CombineMapper):
 def insert_register_prefetches(kernel):
     reg_pf = {}
 
-    total_loop_count = len(kernel.all_indices())
+    total_loop_count = len(kernel.all_inames())
     known_vars = set()
 
     unused_index_exprs = set()
@@ -1325,6 +1301,8 @@ def generate_code(kernel):
 
     new_prefetch = {}
     for i_pf, pf in enumerate(kernel.prefetch.itervalues()):
+        amap = kernel.address_map(pf.index_expr)
+        1/0
         dim_storage_lengths = [pfdim.length for pfdim in pf.dims]
 
         other_pf_sizes = sum(all_pf_sizes[:i_pf]+all_pf_sizes[i_pf+1:])
@@ -1475,12 +1453,12 @@ def get_input_access_descriptors(kernel):
 
     return result
 
-def add_prefetch_dims(kernel, input_access_descr, dims, loc_fetch_axes={}):
+def add_prefetch(kernel, input_access_descr, tags_or_inames, loc_fetch_axes={}):
     """
     :arg input_access_descr: see :func:`get_input_access_descriptors`.
         May also be the name of the variable if there is only one
         reference to that variable.
-    :arg dims: loop dimensions that are used to carry out the prefetch
+    :arg tags_or_inames: loop dimensions that are used to carry out the prefetch
     """
 
     if isinstance(input_access_descr, str):
@@ -1493,7 +1471,7 @@ def add_prefetch_dims(kernel, input_access_descr, dims, loc_fetch_axes={}):
 
         input_access_descr, = var_iads
 
-    dims = [kernel.tag_or_iname_to_iname(s) for s in dims]
+    inames = [kernel.tag_or_iname_to_iname(s) for s in tags_or_inames]
     ivec, iexpr = input_access_descr
 
     new_prefetch = getattr(kernel, "prefetch", {}).copy()
@@ -1505,7 +1483,7 @@ def add_prefetch_dims(kernel, input_access_descr, dims, loc_fetch_axes={}):
             kernel=kernel,
             input_vector=ivec,
             index_expr=iexpr,
-            dims=dims,
+            inames=inames,
             loc_fetch_axes={})
 
     return kernel.copy(prefetch=new_prefetch)
-- 
GitLab