From 3534f2d263bf19987ba033f7b48afd8b88598013 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 27 Aug 2012 02:24:15 -0400
Subject: [PATCH] Allow slice ("a[:,:]") prefetch specs.

---
 MEMO                    |  10 +-
 loopy/__init__.py       | 131 +++++++++++++++++++------
 loopy/cse.py            | 208 +++++++++++++++++++++++-----------------
 loopy/kernel.py         |  33 ++++++-
 test/test_linalg.py     |  34 +++++++
 test/test_nbody.py      |   3 +-
 test/test_sem_reagan.py |   4 +-
 7 files changed, 296 insertions(+), 127 deletions(-)

diff --git a/MEMO b/MEMO
index d9d823b65..5336b8f48 100644
--- a/MEMO
+++ b/MEMO
@@ -41,18 +41,18 @@ Things to consider
 To-do
 ^^^^^
 
+- Kernel splitting (via what variables get computed in a kernel)
+
+Fixes:
+
 - Group instructions by dependency/inames for scheduling, to
   increase sched. scalability
 
-- Kernel splitting (via what variables get computed in a kernel)
-
 - What if no universally valid precompute base index expression is found?
   (test_intel_matrix_mul with n = 6*16, e.g.?)
 
 - If finding a maximum proves troublesome, move parameters into the domain
 
-- : (as in, Matlab full-slice) in prefetches
-
 Future ideas
 ^^^^^^^^^^^^
 
@@ -105,6 +105,8 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- : (as in, Matlab full-slice) in prefetches
+
 - Add dependencies after the fact
 
 - Scalar insn priority
diff --git a/loopy/__init__.py b/loopy/__init__.py
index dc6c10c85..e0b1932b9 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -295,6 +295,93 @@ def tag_dimensions(kernel, iname_to_tag, force=False):
 
 # {{{ convenience: add_prefetch
 
+# {{{ process footprint_subscripts
+
+def _add_kernel_axis(kernel, axis_name, start, stop, base_inames):
+    from loopy.kernel import DomainChanger
+    domch = DomainChanger(kernel, base_inames)
+
+    domain = domch.domain
+    new_dim_idx = domain.dim(dim_type.set)
+    domain = (domain
+            .insert_dims(dim_type.set, new_dim_idx, 1)
+            .set_dim_name(dim_type.set, new_dim_idx, axis_name))
+
+    from loopy.isl_helpers import make_slab
+    slab = make_slab(domain.get_space(), axis_name, start, stop)
+
+    domain = domain & slab
+
+    return kernel.copy(domains=domch.get_domains_with(domain))
+
+def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
+        footprint_subscripts, arg, newly_created_vars):
+    """Track applied iname rewrites, deal with slice specifiers ':'."""
+
+    from pymbolic.primitives import Variable
+
+    if footprint_subscripts is None:
+        return kernel, rule_name, sweep_inames, []
+
+    if not isinstance(footprint_subscripts, (list, tuple)):
+        footprint_subscripts = [footprint_subscripts]
+
+    inames_to_be_removed = []
+
+    new_footprint_subscripts = []
+    for fsub in footprint_subscripts:
+        if isinstance(fsub, str):
+            from loopy.symbolic import parse
+            fsub = parse(fsub)
+
+        if not isinstance(fsub, tuple):
+            fsub = (fsub,)
+
+        if len(fsub) != arg.dimensions:
+            raise ValueError("sweep index '%s' has the wrong number of dimensions")
+
+        for subst_map in kernel.applied_iname_rewrites:
+            from loopy.symbolic import SubstitutionMapper
+            from pymbolic.mapper.substitutor import make_subst_func
+            fsub = SubstitutionMapper(make_subst_func(subst_map))(fsub)
+
+        from loopy.symbolic import get_dependencies
+        fsub_dependencies = get_dependencies(fsub)
+
+        new_fsub = []
+        for axis_nr, fsub_axis in enumerate(fsub):
+            from pymbolic.primitives import Slice
+            if isinstance(fsub_axis, Slice):
+                if fsub_axis.children != (None,):
+                    raise NotImplementedError("add_prefetch only "
+                            "supports full slices")
+
+                axis_name = kernel.make_unique_var_name(
+                        based_on="%s_fetch_axis_%d" % (arg.name, axis_nr),
+                        extra_used_vars=newly_created_vars)
+
+                newly_created_vars.add(axis_name)
+                kernel = _add_kernel_axis(kernel, axis_name, 0, arg.shape[axis_nr],
+                        frozenset(sweep_inames) | fsub_dependencies)
+                sweep_inames = sweep_inames + [axis_name]
+
+                inames_to_be_removed.append(axis_name)
+                new_fsub.append(Variable(axis_name))
+
+            else:
+                new_fsub.append(fsub_axis)
+
+        new_footprint_subscripts.append(tuple(new_fsub))
+        del new_fsub
+
+    footprint_subscripts = new_footprint_subscripts
+    del new_footprint_subscripts
+
+    subst_use = [Variable(rule_name)(*si) for si in footprint_subscripts]
+    return kernel, subst_use, sweep_inames, inames_to_be_removed
+
+# }}}
+
 def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         default_tag="l.auto", rule_name=None, footprint_subscripts=None):
     """Prefetch all accesses to the variable *var_name*, with all accesses
@@ -380,44 +467,32 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
     kernel = extract_subst(kernel, rule_name, uni_template, parameters)
 
-    # {{{ track applied iname rewrites on footprint_subscripts
+    kernel, subst_use, sweep_inames, inames_to_be_removed = \
+            _process_footprint_subscripts(
+                    kernel,  rule_name, sweep_inames,
+                    footprint_subscripts, arg, newly_created_vars)
 
-    if footprint_subscripts is not None:
-        if not isinstance(footprint_subscripts, (list, tuple)):
-            footprint_subscripts = [footprint_subscripts]
-
-        def standardize_footprint_indices(si):
-            if isinstance(si, str):
-                from loopy.symbolic import parse
-                si = parse(si)
+    new_kernel = precompute(kernel, subst_use, arg.dtype, sweep_inames,
+            new_storage_axis_names=dim_arg_names,
+            default_tag=default_tag)
 
-            if not isinstance(si, tuple):
-                si = (si,)
+    # {{{ remove inames that were temporarily added by slice sweeps
 
-            if len(si) != arg.dimensions:
-                raise ValueError("sweep index '%s' has the wrong number of dimensions")
+    new_domains = new_kernel.domains[:]
 
-            for subst_map in kernel.applied_iname_rewrites:
-                from loopy.symbolic import SubstitutionMapper
-                from pymbolic.mapper.substitutor import make_subst_func
-                si = SubstitutionMapper(make_subst_func(subst_map))(si)
+    for iname in inames_to_be_removed:
+        home_domain_index = kernel.get_home_domain_index(iname)
+        domain = new_domains[home_domain_index]
 
-            return si
+        dt, idx = domain.get_var_dict()[iname]
+        assert dt == dim_type.set
 
-        footprint_subscripts = [standardize_footprint_indices(si) for si in footprint_subscripts]
+        new_domains[home_domain_index] = domain.project_out(dt, idx, 1)
 
-        from pymbolic.primitives import Variable
-        subst_use = [
-                Variable(rule_name)(*si) for si in footprint_subscripts]
-    else:
-        subst_use = rule_name
+    new_kernel = new_kernel.copy(domains=new_domains)
 
     # }}}
 
-    new_kernel = precompute(kernel, subst_use, arg.dtype, sweep_inames,
-            new_storage_axis_names=dim_arg_names,
-            default_tag=default_tag)
-
     # If the rule survived past precompute() (i.e. some accesses fell outside
     # the footprint), get rid of it before moving on.
     if rule_name in new_kernel.substitutions:
diff --git a/loopy/cse.py b/loopy/cse.py
index 236057ea3..08341a2b8 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -49,10 +49,11 @@ def to_parameters_or_project_out(param_inames, set_inames, set):
 
 # {{{ construct storage->sweep map
 
-def build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
-        storage_axis_names, storage_axis_sources, prime_sweep_inames):
+def build_per_access_storage_to_domain_map(invdesc, domain,
+        storage_axis_names, storage_axis_sources,
+        prime_sweep_inames):
 
-    map_space = domain_dup_sweep.get_space()
+    map_space = domain.get_space()
     stor_dim = len(storage_axis_names)
     rn = map_space.dim(dim_type.out)
 
@@ -63,13 +64,13 @@ def build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
 
         map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'")
 
-    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep]
+    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)
 
     set_space = map_space.move_dims(
             dim_type.out, rn,
             dim_type.in_, 0, stor_dim).range()
 
-    # set_space: [domain](dup_sweep_index)[dup_sweep][stor_axes']
+    # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes']
 
     stor2sweep = None
 
@@ -102,15 +103,30 @@ def build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
     # stor2sweep is back in map_space
     return stor2sweep
 
-def build_global_storage_to_sweep_map(invocation_descriptors,
-        dup_sweep_index, domain_dup_sweep,
-        storage_axis_names, storage_axis_sources, prime_sweep_inames):
+def move_to_par_from_out(s2smap, except_inames):
+    while True:
+        var_dict = s2smap.get_var_dict(dim_type.out)
+        todo_inames = set(var_dict) - except_inames
+        if todo_inames:
+            iname = todo_inames.pop()
+
+            _, dim_idx = var_dict[iname]
+            s2smap = s2smap.move_dims(
+                    dim_type.param, s2smap.dim(dim_type.param),
+                    dim_type.out, dim_idx, 1)
+        else:
+            return s2smap
+
+def build_global_storage_to_sweep_map(kernel, invocation_descriptors,
+        domain_dup_sweep, dup_sweep_index,
+        storage_axis_names, storage_axis_sources,
+        sweep_inames, primed_sweep_inames, prime_sweep_inames):
     """
     As a side effect, this fills out is_in_footprint in the
     invocation descriptors.
     """
 
-    # The storage map goes from storage axes to domain_dup_sweep.
+    # The storage map goes from storage axes to the domain.
     # The first len(arg_names) storage dimensions are the rule's arguments.
 
     global_stor2sweep = None
@@ -118,8 +134,10 @@ def build_global_storage_to_sweep_map(invocation_descriptors,
     # build footprint
     for invdesc in invocation_descriptors:
         if invdesc.expands_footprint:
-            stor2sweep = build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
-                    storage_axis_names, storage_axis_sources, prime_sweep_inames)
+            stor2sweep = build_per_access_storage_to_domain_map(
+                    invdesc, domain_dup_sweep,
+                    storage_axis_names, storage_axis_sources,
+                    prime_sweep_inames)
 
             if global_stor2sweep is None:
                 global_stor2sweep = stor2sweep
@@ -132,33 +150,55 @@ def build_global_storage_to_sweep_map(invocation_descriptors,
         global_stor2sweep = isl.Map.from_basic_map(stor2sweep)
     global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep)
 
-    # function to move non-sweep inames into parameter space
-    def move_non_sweep_to_par(s2smap):
-        sp = s2smap.get_space()
-        return s2smap.move_dims(
-                dim_type.param, sp.dim(dim_type.param),
-                dim_type.out, 0, dup_sweep_index)
 
-    global_s2s_par_dom = move_non_sweep_to_par(global_stor2sweep).domain()
+    # {{{ check if non-footprint-building invocation descriptors fall into footprint
+
+    # Make all inames except the sweep parameters. (The footprint may depend on those.)
+    # (I.e. only leave sweep inames as out parameters.)
+    global_s2s_par_dom = move_to_par_from_out(
+            global_stor2sweep, except_inames=frozenset(primed_sweep_inames)).domain()
 
-    # check if non-footprint-building invocation descriptors fall into footprint
     for invdesc in invocation_descriptors:
-        stor2sweep = build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
-                    storage_axis_names, storage_axis_sources, prime_sweep_inames)
+        if not invdesc.expands_footprint:
+            arg_inames = set()
 
-        if isinstance(stor2sweep, isl.BasicMap):
-            stor2sweep = isl.Map.from_basic_map(stor2sweep)
+            for arg in invdesc.args:
+                arg_inames.update(get_dependencies(arg))
+            arg_inames = frozenset(arg_inames)
 
-        stor2sweep = move_non_sweep_to_par(
-                stor2sweep.intersect_range(domain_dup_sweep))
+            usage_domain = kernel.get_inames_domain(arg_inames)
+            for i in xrange(usage_domain.dim(dim_type.set)):
+                iname = usage_domain.get_dim_name(dim_type.set, i)
+                if iname in sweep_inames:
+                    usage_domain = usage_domain.set_dim_name(
+                            dim_type.set, i, iname+"'")
 
-        is_in_footprint = stor2sweep.domain().is_subset(
-                global_s2s_par_dom)
+            stor2sweep = build_per_access_storage_to_domain_map(invdesc,
+                    usage_domain, storage_axis_names, storage_axis_sources,
+                    prime_sweep_inames)
+
+            if isinstance(stor2sweep, isl.BasicMap):
+                stor2sweep = isl.Map.from_basic_map(stor2sweep)
+
+            stor2sweep = stor2sweep.intersect_range(usage_domain)
+
+            stor2sweep = move_to_par_from_out(stor2sweep,
+                    except_inames=frozenset(primed_sweep_inames))
+
+            s2s_domain = stor2sweep.domain()
+            s2s_domain, aligned_g_s2s_parm_dom = isl.align_two(
+                    s2s_domain, global_s2s_par_dom)
+
+            s2s_domain = s2s_domain.project_out_except(
+                    arg_inames, [dim_type.param])
+            aligned_g_s2s_parm_dom = aligned_g_s2s_parm_dom.project_out_except(
+                    arg_inames, [dim_type.param])
+
+            is_in_footprint = s2s_domain.is_subset(aligned_g_s2s_parm_dom)
 
-        if not invdesc.expands_footprint:
             invdesc.is_in_footprint = is_in_footprint
-        else:
-            assert is_in_footprint
+
+    # }}}
 
     return global_stor2sweep
 
@@ -176,18 +216,11 @@ def find_var_base_indices_and_shape_from_inames(
 
 
 
-def compute_bounds(kernel, sweep_domain, subst_name, stor2sweep, sweep_inames,
-        storage_axis_names):
+def compute_bounds(kernel, domain, subst_name, stor2sweep,
+        primed_sweep_inames, storage_axis_names):
 
-    # move non-sweep inames into parameter space
-
-    dup_sweep_index = sweep_domain.get_space().dim(dim_type.out)
-    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep]
-
-    sp = stor2sweep.get_space()
-    bounds_footprint_map = stor2sweep.move_dims(
-            dim_type.param, sp.dim(dim_type.param),
-            dim_type.out, 0, dup_sweep_index)
+    bounds_footprint_map = move_to_par_from_out(
+            stor2sweep, except_inames=frozenset(primed_sweep_inames))
 
     # compute bounds for each storage axis
     storage_domain = bounds_footprint_map.domain().coalesce()
@@ -207,17 +240,20 @@ def compute_bounds(kernel, sweep_domain, subst_name, stor2sweep, sweep_inames,
 
 
 
-def get_access_info(kernel, sweep_domain, subst_name,
+def get_access_info(kernel, domain, subst_name,
         storage_axis_names, storage_axis_sources,
         sweep_inames, invocation_descriptors):
 
     # {{{ duplicate sweep inames
 
+    # The duplication is necessary, otherwise the storage fetch
+    # inames remain weirdly tied to the original sweep inames.
+
     primed_sweep_inames = [psin+"'" for psin in sweep_inames]
     from loopy.isl_helpers import duplicate_axes
-    dup_sweep_index = sweep_domain.space.dim(dim_type.out)
+    dup_sweep_index = domain.space.dim(dim_type.out)
     domain_dup_sweep = duplicate_axes(
-            sweep_domain, sweep_inames,
+            domain, sweep_inames,
             primed_sweep_inames)
 
     prime_sweep_inames = SubstitutionMapper(make_subst_func(
@@ -226,11 +262,13 @@ def get_access_info(kernel, sweep_domain, subst_name,
     # }}}
 
     stor2sweep = build_global_storage_to_sweep_map(
-            invocation_descriptors, dup_sweep_index, domain_dup_sweep,
-            storage_axis_names, storage_axis_sources, prime_sweep_inames)
+            kernel, invocation_descriptors,
+            domain_dup_sweep, dup_sweep_index,
+            storage_axis_names, storage_axis_sources,
+            sweep_inames, primed_sweep_inames, prime_sweep_inames)
 
     storage_base_indices, storage_shape = compute_bounds(
-            kernel, sweep_domain, subst_name, stor2sweep, sweep_inames,
+            kernel, domain, subst_name, stor2sweep, primed_sweep_inames,
             storage_axis_names)
 
     # compute augmented domain
@@ -285,17 +323,11 @@ def get_access_info(kernel, sweep_domain, subst_name,
     # }}}
 
     # eliminate (primed) storage axes with non-zero base indices
+    aug_domain = aug_domain.project_out(dim_type.set, stor_idx+nn1_stor, n_stor)
 
-    aug_domain = aug_domain.eliminate(dim_type.set, stor_idx+nn1_stor, n_stor)
-    aug_domain = aug_domain.remove_dims(dim_type.set, stor_idx+nn1_stor, n_stor)
-
-    # {{{ eliminate duplicated sweep_inames
-
+    # eliminate duplicated sweep_inames
     nsweep = len(sweep_inames)
-    aug_domain = aug_domain.eliminate(dim_type.set, dup_sweep_index, nsweep)
-    aug_domain = aug_domain.remove_dims(dim_type.set, dup_sweep_index, nsweep)
-
-    # }}}
+    aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index, nsweep)
 
     return (non1_storage_axis_names, aug_domain,
             storage_base_indices, non1_storage_base_indices, non1_storage_shape)
@@ -519,33 +551,28 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     sweep_inames = list(sweep_inames)
 
-    # {{{ see if we need extra storage dimensions
+    # {{{ find inames used in argument dependencies
 
-    # find inames used in argument dependencies
+    expanding_usage_arg_deps = set()
 
-    usage_arg_deps = set()
     for invdesc in invocation_descriptors:
-        if not invdesc.expands_footprint:
-            continue
+        if invdesc.expands_footprint:
+            for arg in invdesc.args:
+                expanding_usage_arg_deps.update(get_dependencies(arg))
 
-        for arg in invdesc.args:
-            usage_arg_deps.update(get_dependencies(arg))
+    # }}}
 
-    extra_storage_axes = list(set(sweep_inames) - usage_arg_deps)
+    newly_created_var_names = set()
+
+    # {{{ use given / find new storage_axes
+
+    extra_storage_axes = list(set(sweep_inames) - expanding_usage_arg_deps)
 
     if storage_axes is None:
         storage_axes = (
                 extra_storage_axes
                 + list(xrange(len(arg_names))))
 
-    # }}}
-
-    newly_created_var_names = set()
-
-    # {{{ process storage_axes argument
-
-    # (and substitute in subst_expressions if any variable name changes are necessary)
-
     expr_subst_dict = {}
 
     storage_axis_names = []
@@ -597,27 +624,33 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     # }}}
 
-    referenced_inames = frozenset(sweep_inames) | frozenset(usage_arg_deps)
-    assert referenced_inames <= kernel.all_inames()
+    expanding_inames = frozenset(sweep_inames) | frozenset(expanding_usage_arg_deps)
+    assert expanding_inames <= kernel.all_inames()
+
+    # {{{ find domain to be changed
 
-    if referenced_inames:
-        leaf_domain_index = kernel.get_leaf_domain_index(referenced_inames)
-        sweep_domain = kernel.domains[leaf_domain_index]
+    from loopy.kernel import DomainChanger
+    domch = DomainChanger(kernel, expanding_inames)
+
+    if domch.leaf_domain_index is not None:
+        # If the sweep inames are at home in parent domains, then we'll add
+        # fetches with loops over copies of these parent inames that will end
+        # up being scheduled *within* loops over these parents.
 
         for iname in sweep_inames:
-            if kernel.get_home_domain_index(iname) != leaf_domain_index:
+            if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                 raise RuntimeError("sweep iname '%s' is not 'at home' in the "
                         "sweep's leaf domain" % iname)
-    else:
-        sweep_domain = kernel.combine_domains(())
-        leaf_domain_index = None
+
+    # }}}
 
     (non1_storage_axis_names, new_domain,
             storage_base_indices, non1_storage_base_indices, non1_storage_shape) = \
-                    get_access_info(kernel, sweep_domain, subst_name,
+                    get_access_info(kernel, domch.domain, subst_name,
                             storage_axis_names, storage_axis_sources,
                             sweep_inames, invocation_descriptors)
 
+
     # {{{ try a few ways to get new_domain to be convex
 
     if len(new_domain.get_basic_sets()) > 1:
@@ -635,6 +668,9 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
     if isinstance(new_domain, isl.Set):
         dom_bsets = new_domain.get_basic_sets()
         if len(dom_bsets) > 1:
+            print "PIECES:"
+            for dbs in dom_bsets:
+                print "  %s" % (isl.Set.from_basic_set(dbs).gist(new_domain))
             raise NotImplementedError("Substitution '%s' yielded a non-convex footprint"
                     % subst_name)
 
@@ -817,14 +853,8 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     # }}}
 
-    new_domains = kernel.domains[:]
-    if leaf_domain_index is not None:
-        new_domains[leaf_domain_index] = new_domain
-    else:
-        new_domains.append(new_domain)
-
     return kernel.copy(
-            domains=new_domains,
+            domains=domch.get_domains_with(new_domain),
             instructions=new_insns,
             substitutions=new_substs,
             temporary_variables=new_temporary_variables,
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 6fd09982a..3c7ee5654 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -1362,9 +1362,9 @@ class LoopKernel(Record):
         domain = self.get_inames_domain(frozenset([iname]))
         d_var_dict = domain.get_var_dict()
 
-        dom_intersect_assumptions = (isl.align_spaces(
-                self.assumptions, domain, obj_bigger_ok=True)
-                & domain)
+        assumptions, domain = isl.align_two(self.assumptions, domain)
+
+        dom_intersect_assumptions = assumptions & domain
 
         lower_bound_pw_aff = (
                 self.cache_manager.dim_min(
@@ -1755,4 +1755,31 @@ class SetOperationCacheManager:
 
 
 
+class DomainChanger:
+    """Helps change the domain responsible for *inames* within a kernel.
+
+    .. note: Does not perform an in-place change!
+    """
+
+    def __init__(self, kernel, inames):
+        self.kernel = kernel
+        if inames:
+            self.leaf_domain_index = kernel.get_leaf_domain_index(inames)
+            self.domain = kernel.domains[self.leaf_domain_index]
+
+        else:
+            self.domain = kernel.combine_domains(())
+            self.leaf_domain_index = None
+
+    def get_domains_with(self, replacement):
+        result = self.kernel.domains[:]
+        if self.leaf_domain_index is not None:
+            result[self.leaf_domain_index] = replacement
+        else:
+            result.append(replacement)
+
+        return result
+
+
+
 # vim: foldmethod=marker
diff --git a/test/test_linalg.py b/test/test_linalg.py
index de14bac1e..3b2cad4dc 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -608,6 +608,40 @@ def test_fancy_matrix_mul(ctx_factory):
 
 
 
+def test_small_batched_matvec(ctx_factory):
+    dtype = np.float32
+    ctx = ctx_factory()
+
+    order = "C"
+
+    K = 10000
+    Np = 36
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "[K] -> {[i,j,k]: 0<=k<K and 0<= i,j < %d}" % Np,
+            [
+                "result[k, i] = sum(j, d[i, j]*f[k, j])"
+                ],
+            [
+                lp.GlobalArg("d", dtype, shape=(Np, Np), order=order),
+                lp.GlobalArg("f", dtype, shape=("K", Np), order=order),
+                lp.GlobalArg("result", dtype, shape=("K", Np), order=order),
+                lp.ValueArg("K", np.int32, approximately=1000),
+                ], name="batched_matvec", assumptions="K>=1")
+
+    seq_knl = knl
+
+    knl = lp.add_prefetch(knl, 'd[:,:]')
+
+    kernel_gen = lp.generate_loop_schedules(knl)
+    kernel_gen = lp.check_kernels(kernel_gen, dict(K=K))
+
+    lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
+            op_count=[K*2*Np**2/1e9], op_label=["GFlops"],
+            parameters=dict(K=K))
+
+
+
 
 if __name__ == "__main__":
     import sys
diff --git a/test/test_nbody.py b/test/test_nbody.py
index aa9812527..90859883b 100644
--- a/test/test_nbody.py
+++ b/test/test_nbody.py
@@ -56,7 +56,8 @@ def test_nbody(ctx_factory):
 
     n = 3000
 
-    for variant in [variant_1, variant_cpu, variant_gpu]:
+    for variant in [ variant_cpu]:
+    #for variant in [variant_1, variant_cpu, variant_gpu]:
         variant_knl, loop_prio = variant(knl)
         kernel_gen = lp.generate_loop_schedules(variant_knl,
                 loop_priority=loop_prio)
diff --git a/test/test_sem_reagan.py b/test/test_sem_reagan.py
index 84c966757..176f1d004 100644
--- a/test/test_sem_reagan.py
+++ b/test/test_sem_reagan.py
@@ -54,8 +54,8 @@ def test_tim2d(ctx_factory):
     def variant_orig(knl):
         knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1", e="g.0"))
 
-        knl = lp.add_prefetch(knl, "D", ["m", "j", "i","o"])
-        knl = lp.add_prefetch(knl, "u", ["i", "j", "o_ur", "o_us"])
+        knl = lp.add_prefetch(knl, "D[:,:]")
+        knl = lp.add_prefetch(knl, "u[e, :, :]")
 
         knl = lp.precompute(knl, "ur(m,j)", np.float32, ["m", "j"])
         knl = lp.precompute(knl, "us(i,m)", np.float32, ["i", "m"])
-- 
GitLab