diff --git a/MEMO b/MEMO
index 589e672b5427643c07c0dddca61601b9dafdf5a8..1a76b61e040098083629219715aade111f56ad0a 100644
--- a/MEMO
+++ b/MEMO
@@ -62,6 +62,8 @@ Things to consider
 TODO
 ^^^^
 
+- Make axpy better.
+
 - implemented_domain may end up being smaller than requested in cse
   evaluations--check that!
 
@@ -88,6 +90,8 @@ TODO
 Dealt with
 ^^^^^^^^^^
 
+- Screwy lower bounds in slab decomposition
+
 - reimplement add_prefetch
 
 - Flag, exploit idempotence
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 05cc51ecbf5b48a780b0a6612ba04d7d3c0488eb..1b4e057589360d2c663e7cc0fdddf40f093f2c06 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -32,9 +32,6 @@ def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
 # {{{ conditional-minimizing slab decomposition
 
 def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
-    from loopy.isl_helpers import block_shift_constraint, negate_constraint
-
-    ccm = codegen_state.c_code_mapper
     space = kernel.space
     tag = kernel.iname_to_tag.get(iname)
 
@@ -47,35 +44,79 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
 
     iname_tp, iname_idx = kernel.iname_to_dim[iname]
 
-    slabs = []
-    if lower_incr:
-        slabs.append(("initial", isl.Set.universe(kernel.space)
-                .add_constraint(lb_cns_orig)
-                .add_constraint(ub_cns_orig)
-                .add_constraint(
-                    negate_constraint(
-                        block_shift_constraint(
-                            lb_cns_orig, iname_tp, iname_idx, -lower_incr)))))
-
-    slabs.append(("bulk",
-        (isl.Set.universe(kernel.space)
-            .add_constraint(
-                block_shift_constraint(lb_cns_orig, iname_tp, iname_idx, -lower_incr))
-            .add_constraint(
-                block_shift_constraint(ub_cns_orig, iname_tp, iname_idx, -upper_incr)))))
-
-    if upper_incr:
-        slabs.append(("final", isl.Set.universe(kernel.space)
-                .add_constraint(ub_cns_orig)
-                .add_constraint(lb_cns_orig)
-                .add_constraint(
-                    negate_constraint(
-                        block_shift_constraint(
-                            ub_cns_orig, iname_tp, iname_idx, -upper_incr)))))
+    constraints = [lb_cns_orig]
+    if lower_incr or upper_incr:
+        bounds = kernel.get_iname_bounds(iname)
+
+        lower_bound_pw_aff_pieces = bounds.lower_bound_pw_aff.coalesce().get_pieces()
+        upper_bound_pw_aff_pieces = bounds.upper_bound_pw_aff.coalesce().get_pieces()
+
+        if len(lower_bound_pw_aff_pieces) > 1:
+            raise NotImplementedError("lower bound for slab decomp of '%s' needs "
+                    "conditional/has more than one piece" % iname)
+        if len(upper_bound_pw_aff_pieces) > 1:
+            raise NotImplementedError("upper bound for slab decomp of '%s' needs "
+                    "conditional/has more than one piece" % iname)
+
+        (_, lower_bound_aff), = lower_bound_pw_aff_pieces
+        (_, upper_bound_aff), = upper_bound_pw_aff_pieces
+
+        lower_bulk_bound = lb_cns_orig
+        upper_bulk_bound = lb_cns_orig
+
+        from loopy.isl_helpers import iname_rel_aff
+
+        if lower_incr:
+            assert lower_incr > 0
+            lower_slab = ("initial", isl.Set.universe(kernel.space)
+                    .add_constraint(lb_cns_orig)
+                    .add_constraint(ub_cns_orig)
+                    .add_constraint(
+                        isl.Constraint.inequality_from_aff(
+                            iname_rel_aff(kernel.space,
+                                iname, "<", lower_bound_aff+lower_incr))))
+            lower_bulk_bound = (
+                    isl.Constraint.inequality_from_aff(
+                        iname_rel_aff(kernel.space,
+                            iname, ">=", lower_bound_aff+lower_incr)))
+        else:
+            lower_slab = None
+
+        if upper_incr:
+            assert upper_incr > 0
+            upper_slab = ("final", isl.Set.universe(kernel.space)
+                    .add_constraint(lb_cns_orig)
+                    .add_constraint(ub_cns_orig)
+                    .add_constraint(
+                        isl.Constraint.inequality_from_aff(
+                            iname_rel_aff(kernel.space,
+                                iname, ">=", upper_bound_aff-upper_incr))))
+            upper_bulk_bound = (
+                    isl.Constraint.inequality_from_aff(
+                        iname_rel_aff(kernel.space,
+                            iname, "<", upper_bound_aff-upper_incr)))
+        else:
+            lower_slab = None
+
+        slabs = []
+
+        if lower_slab:
+            slabs.append(lower_slab)
+        slabs.append((
+            ("bulk",
+                (isl.Set.universe(kernel.space)
+                    .add_constraint(lower_bulk_bound)
+                    .add_constraint(upper_bulk_bound)))))
+        if upper_slab:
+            slabs.append(upper_slab)
+
+        return slabs
 
-    # }}}
-
-    return lb_cns_orig, ub_cns_orig, slabs
+    else:
+        return [("bulk",
+            (isl.Set.universe(kernel.space)
+            .add_constraint(lb_cns_orig)
+            .add_constraint(ub_cns_orig)))]
 
 # }}}
 
@@ -166,7 +207,7 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state, hw_inames_left=
 
     # }}}
 
-    lb_cns_orig, ub_cns_orig, slabs = get_slab_decomposition(
+    slabs = get_slab_decomposition(
             kernel, iname, sched_index, codegen_state)
 
     if other_inames_with_same_tag and len(slabs) > 1:
@@ -202,7 +243,7 @@ def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
     iname = kernel.schedule[sched_index].iname
     tag = kernel.iname_to_tag.get(iname)
 
-    lb_cns_orig, ub_cns_orig, slabs = get_slab_decomposition(
+    slabs = get_slab_decomposition(
             kernel, iname, sched_index, codegen_state)
 
     result = []
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 1e8490388111c9267294c23149e7e85b352ec186..55504a877c752eda3474ca09f87bfb1fe5f06443 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -109,19 +109,19 @@ def test_axpy(ctx_factory):
                 lp.ArrayArg("z", dtype, shape="n,"),
                 lp.ScalarArg("n", np.int32, approximately=n),
                 ],
-            name="matmul")
+            name="matmul", assumptions="n>=4096")
 
     unroll = 4
     block_size = 256
-    knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, -1))
-    knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0", slabs=(0, -1))
+    knl = lp.split_dimension(knl, "i", unroll*block_size, outer_tag="g.0", slabs=(0, 1))
+    knl = lp.split_dimension(knl, "i_inner", block_size, outer_tag="unr", inner_tag="l.0")
 
     kernel_gen = lp.generate_loop_schedules(knl)
     kernel_gen = lp.check_kernels(kernel_gen, dict(n=n), kill_level_min=5)
 
     a = cl_random.rand(queue, n, dtype=dtype)
     b = cl_random.rand(queue, n, dtype=dtype)
-    c = cl_array.empty_like(a)
+    c = cl_array.zeros_like(a)
     refsol = (2*a+3*b).get()
 
     def launcher(kernel, gsize, lsize, check):