From a439e0097b0e5d2dc7b0c99df7266f3daf16d909 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 11 Oct 2011 23:51:38 -0400
Subject: [PATCH] First passing test for new-style loopy!

---
 MEMO                         |   5 ++
 loopy/__init__.py            | 108 +++++++----------------------------
 loopy/codegen/__init__.py    |  17 +++++-
 loopy/codegen/bounds.py      |  19 +-----
 loopy/codegen/instruction.py |  12 ++--
 loopy/codegen/loop.py        |   5 +-
 loopy/compiled.py            |   5 +-
 loopy/isl.py                 |  68 +++++++++++++---------
 loopy/kernel.py              |  54 +++++++++++-------
 loopy/schedule.py            |  75 +++++++++++++++++++++++-
 test/test_matmul.py          |  10 +---
 11 files changed, 205 insertions(+), 173 deletions(-)

diff --git a/MEMO b/MEMO
index 20f7895ec..dd7220c4b 100644
--- a/MEMO
+++ b/MEMO
@@ -70,9 +70,14 @@ Things to consider
   -> Only reduction
 
 - Slab decomposition for parallel dimensions
+  - implement at the outermost nesting level regardless
+  - bound *all* tagged inames
 
 - Sharing of checks across ILP instances
 
+- Loop bounds currently may not depend on parallel dimensions
+  Does it make sense to relax this?
+
 Dealt with
 ^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index add477149..6f1c9b930 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -79,7 +79,6 @@ def split_dimension(kernel, iname, inner_length, padded_length=None,
                 .project_out(name_dim_type, name_idx, 1))
 
     new_domain = process_set(kernel.domain)
-    new_assumptions = process_set(kernel.assumptions)
 
     from pymbolic import var
     inner = var(inner_iname)
@@ -121,7 +120,6 @@ def split_dimension(kernel, iname, inner_length, padded_length=None,
     iname_slab_increments[outer_iname] = outer_slab_increments
     result = (kernel
             .copy(domain=new_domain,
-                assumptions=new_assumptions,
                 iname_slab_increments=iname_slab_increments,
                 iname_to_dim=None,
                 instructions=new_insns))
@@ -380,7 +378,6 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
         return result, new_iname_to_dim
 
     new_domain, new_iname_to_dim = realize_duplication(kernel.domain)
-    new_assumptions, _ = realize_duplication(kernel.assumptions)
 
     # }}}
 
@@ -397,7 +394,7 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
         from loopy.symbolic import pw_aff_to_expr
 
         target_var_shape.append(static_max_of_pw_aff(
-                upper_bound_pw_aff - lower_bound_pw_aff + 1))
+                upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True))
         target_var_base_indices.append(pw_aff_to_expr(lower_bound_pw_aff))
 
     from loopy.kernel import TemporaryVariable
@@ -417,7 +414,6 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
 
     return kernel.copy(
             domain=new_domain,
-            assumptions=new_assumptions,
             instructions=new_insns,
             temporary_variables=new_temporary_variables,
             iname_to_dim=new_iname_to_dim,
@@ -426,81 +422,7 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
 
 
 
-def realize_reduction(kernel, inames=None, reduction_tag=None):
-    new_insns = []
-    new_temporary_variables = kernel.temporary_variables.copy()
-
-    def map_reduction(expr, rec):
-        sub_expr = rec(expr.expr)
-
-        if reduction_tag is not None and expr.tag != reduction_tag:
-            return
-
-        if inames is not None and set(inames) != set(expr.inames):
-            return
-
-        from pymbolic import var
-
-        target_var_name = kernel.make_unique_var_name("red",
-                extra_used_vars=set(tv for tv in new_temporary_variables))
-        target_var = var(target_var_name)
-
-        from loopy.kernel import Instruction
-
-        from loopy.kernel import TemporaryVariable
-        new_temporary_variables[target_var_name] = TemporaryVariable(
-                name=target_var_name,
-                dtype=expr.operation.dtype,
-                shape=(),
-                is_local=False)
-
-        init_insn = Instruction(
-                id=kernel.make_unique_instruction_id(
-                    extra_used_ids=set(ni.id for ni in new_insns)),
-                assignee=target_var,
-                forced_iname_deps=list(insn.all_inames() - set(expr.inames)),
-                expression=expr.operation.neutral_element)
-
-        new_insns.append(init_insn)
-
-        reduction_insn = Instruction(
-                id=kernel.make_unique_instruction_id(
-                    extra_used_ids=set(ni.id for ni in new_insns)),
-                assignee=target_var,
-                expression=expr.operation(target_var, sub_expr),
-                insn_deps=[init_insn.id],
-                forced_iname_deps=list(insn.all_inames()))
-
-        new_insns.append(reduction_insn)
-
-        new_insn_insn_deps.append(reduction_insn.id)
-        new_insn_removed_inames.extend(expr.inames)
-
-        return target_var
-
-    from loopy.symbolic import ReductionCallbackMapper
-    cb_mapper = ReductionCallbackMapper(map_reduction)
-
-    for insn in kernel.instructions:
-        new_insn_insn_deps = []
-        new_insn_removed_inames = []
-
-        new_expression = cb_mapper(insn.expression)
-
-        new_insns.append(
-                insn.copy(
-                    expression=new_expression,
-                    insn_deps=insn.insn_deps
-                        + new_insn_insn_deps))
-
-    return kernel.copy(
-            instructions=new_insns,
-            temporary_variables=new_temporary_variables)
-
-
-
-
-def get_problems(kernel, parameters, emit_warnings=True):
+def get_problems(kernel, parameters):
     """
     :return: *(max_severity, list of (severity, msg))*, where *severity* ranges from 1-5.
         '5' means 'will certainly not run'.
@@ -508,15 +430,9 @@ def get_problems(kernel, parameters, emit_warnings=True):
     msgs = []
 
     def msg(severity, s):
-        if emit_warnings:
-            from warnings import warn
-            from loopy import LoopyAdvisory
-            warn(s, LoopyAdvisory)
-
         msgs.append((severity, s))
 
-    glens = kernel.tag_type_lengths(TAG_GROUP_IDX, allow_parameters=True)
-    llens = kernel.tag_type_lengths(TAG_LOCAL_IDX, allow_parameters=False)
+    glens, llens = kernel.get_grid_sizes_as_exprs()
 
     from pymbolic import evaluate
     glens = evaluate(glens, parameters)
@@ -534,6 +450,7 @@ def get_problems(kernel, parameters, emit_warnings=True):
     if product(llens) > kernel.device.max_work_group_size:
         msg(5, "work group too big")
 
+    import pyopencl as cl
     from pyopencl.characterize import usable_local_mem_size
     if kernel.local_mem_use() > usable_local_mem_size(kernel.device):
         if kernel.device.local_mem_type == cl.device_local_mem_type.LOCAL:
@@ -554,6 +471,23 @@ def get_problems(kernel, parameters, emit_warnings=True):
         max_severity = max(sev, max_severity)
     return max_severity, msgs
 
+
+
+
+def check_kernels(kernel_gen, parameters, kill_level_min=3,
+        warn_level_min=1):
+    for kernel in kernel_gen:
+        max_severity, msgs = get_problems(kernel, parameters)
+
+        for severity, msg in msgs:
+            if severity >= warn_level_min:
+                from warnings import warn
+                from loopy import LoopyAdvisory
+                warn(msg, LoopyAdvisory)
+
+        if max_severity < kill_level_min:
+            yield kernel
+
 # }}}
 
 # {{{ high-level modifiers
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 65a07c6d0..b15178dc0 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -217,6 +217,9 @@ def generate_code(kernel):
             ( (a) - ( ((a)<0) ? ((b)-1) : 0 )  ) / (b) \
             )
 
+        #define lid(N) ((int) get_local_id(N))
+        #define gid(N) ((int) get_group_id(N))
+
         """),
         Line()])
 
@@ -243,15 +246,23 @@ def generate_code(kernel):
 
     from loopy.codegen.dispatch import build_loop_nest
 
+    from islpy import align_spaces
+    initial_implemented_domain = align_spaces(kernel.assumptions, kernel.domain)
     gen_code = build_loop_nest(kernel, 0,
-            CodeGenerationState(kernel.assumptions, c_code_mapper=ccm))
-    body.extend([Line(), gen_code.ast])
+            CodeGenerationState(initial_implemented_domain, c_code_mapper=ccm))
+
+    body.append(Line())
+
+    if isinstance(gen_code.ast, Block):
+        body.extend(gen_code.ast.contents)
+    else:
+        body.append(gen_code.ast)
 
     from loopy.symbolic import pw_aff_to_expr
     mod.append(
         FunctionBody(
             CLRequiredWorkGroupSize(
-                tuple(pw_aff_to_expr(sz) for sz in kernel.fix_grid_sizes()[1]),
+                tuple(pw_aff_to_expr(sz) for sz in kernel.get_grid_sizes()[1]),
                 CLKernel(FunctionDeclaration(
                     Value("void", kernel.name), args))),
             body))
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index 0b97179b5..06bef83e7 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -31,10 +31,6 @@ def get_bounds_constraints(set, iname, admissible_inames, allow_parameters):
 
     bset, = basic_sets
 
-    # FIXME: hackety hack--elimination leaves the set in an 
-    # invalid ('non-final'?) state
-    bset = bset.intersect(isl.BasicSet.universe(bset.get_space()))
-
     # FIXME perhaps use some form of hull here if there's more than one
     # basic set?
 
@@ -156,11 +152,6 @@ def generate_bounds_checks(domain, check_vars, implemented_domain):
             .coalesce()
             .get_basic_sets())
 
-    # FIXME: hackety hack--elimination leaves the set in an 
-    # invalid ('non-final'?) state
-    domain_bset = domain_bset.intersect(
-            isl.BasicSet.universe(domain_bset.get_space()))
-
     return filter_necessary_constraints(
             implemented_domain, domain_bset.get_constraints())
 
@@ -248,7 +239,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
 
 # {{{ on which variables may a conditional depend?
 
-def get_defined_inames(kernel, sched_index, allow_ilp, exclude_tag_classes=()):
+def get_defined_inames(kernel, sched_index, allow_tag_classes=()):
     """
     :param exclude_tags: a tuple of tag classes to exclude
     """
@@ -264,16 +255,10 @@ def get_defined_inames(kernel, sched_index, allow_ilp, exclude_tag_classes=()):
         elif isinstance(sched_item, LeaveLoop):
             result.remove(sched_item.iname)
 
-    from loopy.kernel import TAG_ILP, ParallelTagWithAxis
     for iname in kernel.all_inames():
         tag = kernel.iname_to_tag.get(iname)
 
-        if isinstance(tag, exclude_tag_classes):
-            continue
-
-        if isinstance(tag, ParallelTagWithAxis):
-            result.add(iname)
-        elif isinstance(tag, TAG_ILP) and allow_ilp:
+        if isinstance(tag, allow_tag_classes):
             result.add(iname)
 
     return result
diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py
index 5651a24d5..e9ddcf2c0 100644
--- a/loopy/codegen/instruction.py
+++ b/loopy/codegen/instruction.py
@@ -57,11 +57,11 @@ def generate_ilp_instances(kernel, insn, codegen_state):
         tag = kernel.iname_to_tag.get(iname)
 
         if isinstance(tag, TAG_LOCAL_IDX):
-            hw_axis_expr = var("(int) get_local_id")(tag.axis)
+            hw_axis_expr = var("lid")(tag.axis)
             hw_axis_size = local_size[tag.axis]
 
         elif isinstance(tag, TAG_GROUP_IDX):
-            hw_axis_expr = var("(int) get_group_id")(tag.axis)
+            hw_axis_expr = var("gid")(tag.axis)
             hw_axis_size = global_size[tag.axis]
 
         else:
@@ -70,12 +70,12 @@ def generate_ilp_instances(kernel, insn, codegen_state):
         bounds = kernel.get_iname_bounds(iname)
 
         from loopy.isl import make_slab
-        impl_domain = impl_domain.intersect(
-                make_slab(impl_domain.get_space(), iname,
-                    bounds.lower_bound_pw_aff, bounds.lower_bound_pw_aff+hw_axis_size))
+        slab = make_slab(impl_domain.get_space(), iname,
+                bounds.lower_bound_pw_aff, bounds.lower_bound_pw_aff+hw_axis_size)
+        impl_domain = impl_domain.intersect(slab)
 
         from loopy.symbolic import pw_aff_to_expr
-        assignments[iname] = pw_aff_to_expr(bounds.lower_bound_pw_aff + hw_axis_expr)
+        assignments[iname] = pw_aff_to_expr(bounds.lower_bound_pw_aff) + hw_axis_expr
 
     # }}} 
 
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 0260ca858..e667a9d76 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -17,7 +17,7 @@ def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
     lower_constraints_orig, upper_constraints_orig, equality_constraints_orig = \
             get_bounds_constraints(kernel.domain, iname,
                     frozenset([iname])
-                    | frozenset(get_defined_inames(kernel, sched_index+1, allow_ilp=False)),
+                    | frozenset(get_defined_inames(kernel, sched_index+1)),
                     allow_parameters=True)
 
     assert not equality_constraints_orig
@@ -27,6 +27,9 @@ def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
 
     return lb_cns_orig, ub_cns_orig
 
+
+
+
 # {{{ conditional-minimizing slab decomposition
 
 def get_slab_decomposition(kernel, sched_index, exec_domain):
diff --git a/loopy/compiled.py b/loopy/compiled.py
index 864683bdd..ed80d2371 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -49,10 +49,7 @@ class CompiledKernel:
             self.size_args = size_args
 
         from loopy.kernel import TAG_GROUP_IDX, TAG_LOCAL_IDX
-        gsize_expr = tuple(self.kernel.tag_type_lengths(
-            TAG_GROUP_IDX, allow_parameters=True))
-        lsize_expr = tuple(self.kernel.tag_type_lengths(
-            TAG_LOCAL_IDX, allow_parameters=False))
+        gsize_expr, lsize_expr = kernel.get_grid_sizes_as_exprs()
 
         if not gsize_expr: gsize_expr = (1,)
         if not lsize_expr: lsize_expr = (1,)
diff --git a/loopy/isl.py b/loopy/isl.py
index fac42ecf6..ef1b6cf51 100644
--- a/loopy/isl.py
+++ b/loopy/isl.py
@@ -90,65 +90,81 @@ def pw_aff_to_aff(pw_aff):
     assert isinstance(pw_aff, isl.PwAff)
     pieces = pw_aff.get_pieces()
 
-    if len(pieces) != 1:
-        raise NotImplementedError("only single-piece PwAff instances are supported here")
+    if len(pieces) == 0:
+        raise RuntimeError("PwAff does not have any pieces")
+    if len(pieces) > 1:
+        _, first_aff = pieces[0]
+        for _, other_aff in pieces[1:]:
+            if not first_aff.plain_is_equal(other_aff):
+                raise NotImplementedError("only single-valued piecewise affine "
+                        "expressions are supported here--encountered "
+                        "multi-valued expression '%s'" % pw_aff)
+
+        return first_aff
 
     return pieces[0][1]
 
 
 
 
-def make_slab(space, iname, start, stop):
-    if isinstance(start, isl.PwAff): start = pw_aff_to_aff(start)
-    if isinstance(stop, isl.PwAff): stop = pw_aff_to_aff(stop)
+def dump_local_space(ls):
+    return " ".join("%s: %d" % (dt, ls.dim(getattr(dim_type, dt))) 
+            for dt in dim_type.names)
 
+def make_slab(space, iname, start, stop):
     zero = isl.Aff.zero_on_domain(space)
 
+    from islpy import align_spaces
+    if isinstance(start, isl.PwAff):
+        start = align_spaces(pw_aff_to_aff(start), zero)
+    if isinstance(stop, isl.PwAff):
+        stop = align_spaces(pw_aff_to_aff(stop), zero)
+
     if isinstance(start, int): start = zero + start
     if isinstance(stop, int): stop = zero + stop
 
     iname_dt, iname_idx = zero.get_space().get_var_dict()[iname]
     iname_aff = zero.add_coefficient(iname_dt, iname_idx, 1)
 
-    return (isl.Set.universe(space)
-            # start <= inner
+    result = (isl.Set.universe(space)
+            # start <= iname
             .add_constraint(isl.Constraint.inequality_from_aff(
                 iname_aff - start))
-            # inner < stop
+            # iname < stop
             .add_constraint(isl.Constraint.inequality_from_aff(
                 stop-1 - iname_aff)))
 
+    return result
 
 
 
-def set_is_universe(set):
-    bs = set.get_basic_sets()
-    if len(bs) == 1:
-        return bs[0].is_universe()
-    else:
-        return isl.Set.universe_like(set).is_subset(set)
-
 
+def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what):
+    pieces = pw_aff.get_pieces()
+    if len(pieces) == 1:
+        return pieces[0][1]
 
+    agg_domain = pw_aff.get_aggregate_domain()
+    for set, candidate_aff in pieces:
+        if constants_only and not candidate_aff.is_cst():
+            continue
 
-def static_min_of_pw_aff(pw_aff):
-    for set, candidate_aff in pw_aff.get_pieces():
-        if set_is_universe(candidate_aff.le_set(pw_aff)):
+        if set_method(pw_aff, candidate_aff) == agg_domain:
             return candidate_aff
 
-    raise ValueError("a static minimum was not found for PwAff '%s'"
-            % pw_aff)
+    raise ValueError("a static %s was not found for PwAff '%s'"
+            % (what, pw_aff))
 
 
 
 
-def static_max_of_pw_aff(pw_aff):
-    for set, candidate_aff in pw_aff.get_pieces():
-        if set_is_universe(candidate_aff.ge_set(pw_aff)):
-            return candidate_aff
+def static_min_of_pw_aff(pw_aff, constants_only):
+    return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.ge_set,
+            "minimum")
 
-    raise ValueError("a static maximum was not found for PwAff '%s'"
-            % pw_aff)
+def static_max_of_pw_aff(pw_aff, constants_only):
+    return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.le_set,
+            "maximum")
 
 
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 1dcfd43cd..762a4c4b6 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -468,16 +468,17 @@ class LoopKernel(Record):
         # }}}
 
         if assumptions is None:
-            assumptions = isl.Set.universe(domain.get_space())
+            assumptions_space = domain.get_space()
+            assumptions_space = assumptions_space.remove_dims(
+                    dim_type.set, 0, assumptions_space.dim(dim_type.set))
+            assumptions = isl.Set.universe(assumptions_space)
         elif isinstance(assumptions, str):
             s = domain.get_space()
             assumptions = isl.BasicSet.read_from_str(domain.get_ctx(),
-                    "[%s] -> {[%s]: %s}"
-                    % (",".join(s.get_name(dim_type.param, i)
-                        for i in range(s.size(dim_type.param))),
-                       ",".join(s.get_name(dim_type.set, i) 
-                           for i in range(s.size(dim_type.set))),
-                       assumptions))
+                    "[%s] -> { : %s}"
+                    % (",".join(s.get_dim_name(dim_type.param, i)
+                        for i in range(s.dim(dim_type.param))),
+                        assumptions))
 
         Record.__init__(self,
                 device=device,  domain=domain, instructions=insns,
@@ -548,8 +549,8 @@ class LoopKernel(Record):
         if self.args is None:
             return []
         else:
-            loop_arg_names = [self.space.get_name(dim_type.param, i)
-                    for i in range(self.space.size(dim_type.param))]
+            loop_arg_names = [self.space.get_dim_name(dim_type.param, i)
+                    for i in range(self.space.dim(dim_type.param))]
             return [arg.name for arg in self.args if isinstance(arg, ScalarArg)
                     if arg.name in loop_arg_names]
 
@@ -565,22 +566,26 @@ class LoopKernel(Record):
     @memoize_method
     def get_iname_bounds(self, iname):
         lower_bound_pw_aff = (self.domain
+                .intersect(self.assumptions)
                 .dim_min(self.iname_to_dim[iname][1])
                 .coalesce())
         upper_bound_pw_aff = (self.domain
+                .intersect(self.assumptions)
                 .dim_max(self.iname_to_dim[iname][1])
                 .coalesce())
 
         class BoundsRecord(Record):
             pass
 
-        size = upper_bound_pw_aff - lower_bound_pw_aff + 1
+        size = (upper_bound_pw_aff - lower_bound_pw_aff + 1)
+        size = size.intersect_domain(self.assumptions)
 
         return BoundsRecord(
                 lower_bound_pw_aff=lower_bound_pw_aff,
                 upper_bound_pw_aff=upper_bound_pw_aff,
                 size=size)
 
+    @memoize_method
     def get_grid_sizes(self):
         all_inames_by_insns = set()
         for insn in self.instructions:
@@ -605,30 +610,28 @@ class LoopKernel(Record):
             elif isinstance(tag, TAG_LOCAL_IDX):
                 tgt_dict = local_sizes
             elif isinstance(tag, TAG_AUTO_LOCAL_IDX):
-                #raise RuntimeError("cannot find grid sizes if AUTO_LOCAL_IDX tags are "
-                        #"present")
-                pass
-                tgt_dict = None
+                raise RuntimeError("cannot find grid sizes if AUTO_LOCAL_IDX tags are "
+                        "present")
             else:
                 tgt_dict = None
 
             if tgt_dict is None:
                 continue
 
-            bounds = self.get_iname_bounds(iname)
+            size = self.get_iname_bounds(iname).size
 
-            size = bounds.size
+            if tag.axis in tgt_dict:
+                size = tgt_dict[tag.axis].max(size)
 
             from loopy.isl import static_max_of_pw_aff
             try:
-                size = static_max_of_pw_aff(size)
+                # insist block size is constant
+                size = static_max_of_pw_aff(size, 
+                        constants_only=isinstance(tag, TAG_LOCAL_IDX))
             except ValueError:
                 pass
 
-            if tag.axis in tgt_dict:
-                tgt_dict[tag.axis] = tgt_dict[tag.axis].max(size)
-            else:
-                tgt_dict[tag.axis] = size
+            tgt_dict[tag.axis] = size
 
         max_dims = self.device.max_work_item_dimensions
 
@@ -655,6 +658,15 @@ class LoopKernel(Record):
         return (to_dim_tuple(global_sizes, "global"),
                 to_dim_tuple(local_sizes, "local"))
 
+    def get_grid_sizes_as_exprs(self):
+        grid_size, group_size = self.get_grid_sizes()
+
+        def tup_to_exprs(tup):
+            from loopy.symbolic import pw_aff_to_expr
+            return tuple(pw_aff_to_expr(i) for i in tup)
+
+        return tup_to_exprs(grid_size), tup_to_exprs(group_size)
+
     def local_mem_use(self):
         return sum(lv.nbytes for lv in self.temporary_variables.itervalues()
                 if lv.is_local)
diff --git a/loopy/schedule.py b/loopy/schedule.py
index 7371b5ae8..f761bc1c9 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -26,6 +26,80 @@ class Barrier(Record):
 
 
 
+def realize_reduction(kernel, inames=None, reduction_tag=None):
+    new_insns = []
+    new_temporary_variables = kernel.temporary_variables.copy()
+
+    def map_reduction(expr, rec):
+        sub_expr = rec(expr.expr)
+
+        if reduction_tag is not None and expr.tag != reduction_tag:
+            return
+
+        if inames is not None and set(inames) != set(expr.inames):
+            return
+
+        from pymbolic import var
+
+        target_var_name = kernel.make_unique_var_name("acc",
+                extra_used_vars=set(tv for tv in new_temporary_variables))
+        target_var = var(target_var_name)
+
+        from loopy.kernel import Instruction
+
+        from loopy.kernel import TemporaryVariable
+        new_temporary_variables[target_var_name] = TemporaryVariable(
+                name=target_var_name,
+                dtype=expr.operation.dtype,
+                shape=(),
+                is_local=False)
+
+        init_insn = Instruction(
+                id=kernel.make_unique_instruction_id(
+                    extra_used_ids=set(ni.id for ni in new_insns)),
+                assignee=target_var,
+                forced_iname_deps=list(insn.all_inames() - set(expr.inames)),
+                expression=expr.operation.neutral_element)
+
+        new_insns.append(init_insn)
+
+        reduction_insn = Instruction(
+                id=kernel.make_unique_instruction_id(
+                    extra_used_ids=set(ni.id for ni in new_insns)),
+                assignee=target_var,
+                expression=expr.operation(target_var, sub_expr),
+                insn_deps=[init_insn.id],
+                forced_iname_deps=list(insn.all_inames()))
+
+        new_insns.append(reduction_insn)
+
+        new_insn_insn_deps.append(reduction_insn.id)
+        new_insn_removed_inames.extend(expr.inames)
+
+        return target_var
+
+    from loopy.symbolic import ReductionCallbackMapper
+    cb_mapper = ReductionCallbackMapper(map_reduction)
+
+    for insn in kernel.instructions:
+        new_insn_insn_deps = []
+        new_insn_removed_inames = []
+
+        new_expression = cb_mapper(insn.expression)
+
+        new_insns.append(
+                insn.copy(
+                    expression=new_expression,
+                    insn_deps=insn.insn_deps
+                        + new_insn_insn_deps))
+
+    return kernel.copy(
+            instructions=new_insns,
+            temporary_variables=new_temporary_variables)
+
+
+
+
 def check_double_use_of_hw_dimensions(kernel):
     from loopy.kernel import UniqueTag
 
@@ -602,7 +676,6 @@ def insert_parallel_dim_check_points(kernel, schedule):
 
 
 def generate_loop_schedules(kernel):
-    from loopy import realize_reduction
     kernel = realize_reduction(kernel)
 
     check_double_use_of_hw_dimensions(kernel)
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 4c51518b2..13a81e0ab 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -211,7 +211,7 @@ def test_plain_matrix_mul_new_ui(ctx_factory):
                 lp.ArrayArg("c", dtype, shape=(n, n), order=order),
                 lp.ScalarArg("n", np.int32, approximately=n),
                 ],
-            name="matmul")
+            name="matmul", assumptions="n >= 1")
 
     knl = lp.split_dimension(knl, "i", 16,
             outer_tag="g.0", inner_tag="l.1", no_slabs=True)
@@ -222,12 +222,8 @@ def test_plain_matrix_mul_new_ui(ctx_factory):
     knl = lp.realize_cse(knl, "lhsmat", dtype, ["k_inner", "i_inner"])
     knl = lp.realize_cse(knl, "rhsmat", dtype, ["j_inner", "k_inner"])
 
-    #print
-    #for insn in knl.instructions:
-        #print insn
-    #assert lp.get_problems(knl, {})[0] <= 2
-
     kernel_gen = lp.generate_loop_schedules(knl)
+    kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=6)
 
     a = make_well_conditioned_dev_matrix(queue, n, dtype=dtype, order=order)
     b = make_well_conditioned_dev_matrix(queue, n, dtype=dtype, order=order)
@@ -235,7 +231,7 @@ def test_plain_matrix_mul_new_ui(ctx_factory):
     refsol = np.dot(a.get(), b.get())
 
     def launcher(kernel, gsize, lsize, check):
-        evt = kernel(queue, gsize(), lsize(), a.data, b.data, c.data,
+        evt = kernel(queue, gsize(n), lsize(n), a.data, b.data, c.data, n,
                 g_times_l=True)
 
         if check:
-- 
GitLab