diff --git a/test/test_fem_assembly.py b/test/test_fem_assembly.py
index 55243ee02b657127bb5bc197722b68dce5b2c1f1..5695d735a5e95cad373506183d35d47a867d09da 100644
--- a/test/test_fem_assembly.py
+++ b/test/test_fem_assembly.py
@@ -15,32 +15,30 @@ def test_laplacian_stiffness(ctx_factory):
     ctx = ctx_factory()
     order = "C"
 
-    dim = 2
+    dim = 2 # (baked into code)
 
-    Nq = 40 # num. quadrature points
-    Nc = 100 # num. cells
-    Nb = 20 # num. basis functions
-
-    # K - run-time symbolic
+    Nq = 40 # num. quadrature points (baked into code)
+    Nb = 20 # num. basis functions (baked into code)
+    Nc = 100 # num. cells (run-time symbolic)
 
     from pymbolic import var
     Nc_sym = var("Nc")
 
     knl = lp.make_kernel(ctx.devices[0],
-            "[Nc] -> {[K,i,j,q, ax_a, ax_b]: 0<=K<Nc and 0<=i,j<%(Nb)d and 0<=q<%(Nq)d "
-            "and 0<= ax_a, ax_b < %(dim)d}"
+            "[Nc] -> {[K,i,j,q, dx_axis, ax_b]: 0<=K<Nc and 0<=i,j<%(Nb)d and 0<=q<%(Nq)d "
+            "and 0<= dx_axis, ax_b < %(dim)d}"
             % dict(Nb=Nb, Nq=Nq, dim=dim),
             [
-                "dPsi(a, dxi) := sum_float32(@ax_b,"
-                    "  jacInv[ax_b,dxi,K,q] * DPsi[ax_b,a,q])",
+                "dPsi(ij, dxi) := sum_float32(@ax_b,"
+                    "  jacInv[ax_b,dxi,K,q] * DPsi[ax_b,ij,q])",
                 "A[K, i, j] = sum_float32(q, w[q] * jacDet[K,q] * ("
-                    "sum_float32(ax_a, dPsi(0,ax_a)*dPsi(1,ax_a))))"
+                    "sum_float32(dx_axis, dPsi.one(i,dx_axis)*dPsi.two(j,dx_axis))))"
                 ],
             [
             lp.ArrayArg("jacInv", dtype, shape=(dim, dim, Nc_sym, Nq), order=order),
             lp.ConstantArrayArg("DPsi", dtype, shape=(dim, Nb, Nq), order=order),
             lp.ArrayArg("jacDet", dtype, shape=(Nc_sym, Nq), order=order),
-            lp.ConstantArrayArg("w", dtype, shape=(Nq, dim), order=order),
+            lp.ConstantArrayArg("w", dtype, shape=(Nq,), order=order),
             lp.ArrayArg("A", dtype, shape=(Nc_sym, Nb, Nb), order=order),
             lp.ScalarArg("Nc",  np.int32, approximately=1000),
             ],
@@ -49,42 +47,89 @@ def test_laplacian_stiffness(ctx_factory):
     knl = lp.tag_dimensions(knl, dict(ax_b="unr"))
     seq_knl = knl
 
-    def variant_1(knl):
-        # no ILP across elements
-        knl = lp.split_dimension(knl, "K", 16, outer_tag="g.0", slabs=(0,1))
-        knl = lp.tag_dimensions(knl, {"i": "l.1", "j": "l.0"})
-        knl = lp.add_prefetch(knl, 'jacInv',
-                ["jacInv_dim_0", "jacInv_dim_1", "K_inner", "q"])
-        return knl
-
-    def variant_2(knl):
-        # with ILP across elements
-        knl = lp.split_dimension(knl, "K", 16, outer_tag="g.0", slabs=(0,1))
-        knl = lp.split_dimension(knl, "K_inner", 4, inner_tag="ilp")
+    def variant_fig31(knl):
+        # This (mostly) reproduces Figure 3.1.
+
+        knl = lp.tag_dimensions(knl, {"dx_axis": "unr"})
+        return knl, ["K", "i", "j", "q", "ax_b_insn"]
+
+    def variant_pg4(knl):
+        # This (mostly) reproduces the unlabeled code snippet on pg. 4.
+
+        knl = lp.tag_dimensions(knl, {"dx_axis": "unr"})
+        Ncloc = 16
+        knl = lp.split_dimension(knl, "K", Ncloc,
+                outer_iname="Ko", inner_iname="Kloc")
+        return knl, ["Ko", "Kloc", "i", "j", "q", "ax_b_insn"]
+
+    def variant_fig32(knl):
+        # This (mostly) reproduces Figure 3.2.
+
+        Ncloc = 16
+        knl = lp.split_dimension(knl, "K", Ncloc,
+                outer_iname="Ko", inner_iname="Kloc")
+        knl = lp.precompute(knl, "dPsi", np.float32, ["i", "q", "dx_axis"],
+                default_tag=None)
+        knl = lp.tag_dimensions(knl, {"dx_axis": "unr", "dxi": "unr"})
+        return knl, ["Ko", "Kloc", "dPsi_q", "ij", "i", "j", "q", "ax_b_insn"]
+
+    def variant_fig33(knl):
+        # This is meant to (mostly) reproduce Figure 3.3.
+        # FIXME: loopy currently does not find a valid schedule for this
+        # variant; this still needs to be fixed.
+
+        Ncloc = 16
+        knl = lp.split_dimension(knl, "K", Ncloc,
+                outer_iname="Ko", inner_iname="Kloc")
+        knl = lp.precompute(knl, "dPsi.one", np.float32, default_tag=None)
+        #knl = lp.precompute(knl, "dPsi.two", np.float32, default_tag=None)
+        return knl, ["Ko", "Kloc"]
+
+    def variant_simple_gpu(knl):
+        # This is a simple GPU-ish variant.
+
+        # It's not the same thing as Matt's code, but I'll need some more time
+        # to reverse-engineer what is going on there. Some discussion might
+        # help, too. :)
+
+        knl = lp.tag_dimensions(knl, {"dx_axis": "unr"})
+        Ncloc = 16
+        knl = lp.split_dimension(knl, "K", Ncloc,
+                outer_iname="Ko", inner_iname="Kloc",
+                outer_tag="g.0")
         knl = lp.tag_dimensions(knl, {"i": "l.1", "j": "l.0"})
-        knl = lp.add_prefetch(knl, "jacInv",
-                ["jacInv_dim_0", "jacInv_dim_1", "K_inner_inner", "K_inner_outer", "q"])
-        return knl
+        return knl, ["K", "i", "j", "q", "ax_b_insn"]
 
-    def variant_3(knl):
-        # no ILP across elements, precompute dPsiTransf
+    def variant_simple_gpu_prefetch(knl):
+        # This adds prefetching to the GPU variant above.
 
-        # generates correct code--but suboptimal in a bunch of ways.
+        # FIXME: In this variant (as observed on the author's machine),
+        # loopy makes a poor choice for the upper bound of Kloc (it uses
+        # Nc). This still needs to be investigated and fixed.
 
-        knl = lp.split_dimension(knl, "K", 16, outer_tag="g.0", slabs=(0,1))
-        knl = lp.add_prefetch(knl, "jacInv",
-                ["jacInv_dim_0", "jacInv_dim_1", "q"])
+        knl = lp.tag_dimensions(knl, {"dx_axis": "unr"})
+        Ncloc = 16
+        knl = lp.split_dimension(knl, "K", Ncloc,
+                outer_iname="Ko", inner_iname="Kloc",
+                outer_tag="g.0")
         knl = lp.tag_dimensions(knl, {"i": "l.1", "j": "l.0"})
-        knl = lp.precompute(knl, "dPsi", np.float32,
-                sweep_axes=["K_inner"])
-        return knl
-
-    for variant in [variant_1, variant_2, variant_3]:
-    #for variant in [variant_3]:
-        kernel_gen = lp.generate_loop_schedules(variant(knl),
-                loop_priority=["jacInv_dim_0", "jacInv_dim_1"])
+        knl = lp.add_prefetch(knl, "w", ["q"])
+        knl = lp.add_prefetch(knl, "DPsi", [0, 1, 2])
+        knl = lp.add_prefetch(knl, "jacInv", [0, 1, 3])
+        knl = lp.add_prefetch(knl, "jacDet", [1])
+        return knl, ["K", "i", "j", "q", "ax_b_insn"]
+
+    # Plug in variant name here
+    #                        |
+    #                        v
+    for variant in [variant_simple_gpu_prefetch]:
+        var_knl, loop_prio = variant(knl)
+        kernel_gen = lp.generate_loop_schedules(var_knl,
+                loop_priority=loop_prio)
         kernel_gen = lp.check_kernels(kernel_gen, dict(Nc=Nc))
 
+        #print lp.preprocess_kernel(var_knl)
+
         lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
                 op_count=0, op_label="GFlops",
                 parameters={"Nc": Nc}, print_ref_code=True,