From c00552272a7b9a7fe05b98dc3710833145b3bc7e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 2 Sep 2012 14:20:17 -0400
Subject: [PATCH] Add (multi-)padding transformation. Adapt test
 infrastructure.

---
 MEMO                |   4 +
 doc/reference.rst   |  18 +++-
 loopy/__init__.py   |  12 ++-
 loopy/compiled.py   | 230 ++++++++++++++++++++++++++++++--------------
 loopy/kernel.py     |  13 ++-
 loopy/padding.py    | 213 ++++++++++++++++++++++++++++++++++++++++
 test/test_linalg.py |   8 +-
 7 files changed, 417 insertions(+), 81 deletions(-)
 create mode 100644 loopy/padding.py

diff --git a/MEMO b/MEMO
index 7d1fa43fd..b5ab32e62 100644
--- a/MEMO
+++ b/MEMO
@@ -45,6 +45,10 @@ To-do
 
 - Kernel splitting (via what variables get computed in a kernel)
 
+- Make xfail test for strided access.
+
+- *_dimension -> *_iname
+
 Fixes:
 
 - Group instructions by dependency/inames for scheduling, to
diff --git a/doc/reference.rst b/doc/reference.rst
index 9b9784bef..2cc0578bd 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -177,10 +177,14 @@ Precomputation and Prefetching
 
     Uses :func:`extract_subst` and :func:`precompute`.
 
-Manipulating Reductions
------------------------
+Padding
+-------
 
-.. autofunction:: realize_reduction
+.. autofunction:: split_arg_axis
+
+.. autofunction:: find_padding_multiple
+
+.. autofunction:: add_padding
 
 Manipulating Instructions
 -------------------------
@@ -206,6 +210,14 @@ Automatic Testing
 Troubleshooting
 ---------------
 
+Special-purpose functionality
+-----------------------------
+
+Manipulating Reductions
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: realize_reduction
+
 Printing :class:`LoopKernel` objects
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index d6bdb4f55..9c6152307 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -29,6 +29,8 @@ from loopy.creation import make_kernel
 from loopy.reduction import register_reduction_parser
 from loopy.subst import extract_subst, expand_subst
 from loopy.cse import precompute
+from loopy.padding import (split_arg_axis, find_padding_multiple,
+        add_padding)
 from loopy.preprocess import preprocess_kernel, realize_reduction
 from loopy.schedule import generate_loop_schedules
 from loopy.codegen import generate_code
@@ -51,7 +53,8 @@ __all__ = ["ValueArg", "ScalarArg", "GlobalArg", "ArrayArg", "ConstantArg", "Ima
         "make_kernel", "split_dimension", "join_dimensions",
         "tag_dimensions",
         "extract_subst", "expand_subst",
-        "precompute", "add_prefetch"
+        "precompute", "add_prefetch",
+        "split_arg_axis", "find_padding_multiple", "add_padding"
         ]
 
 class infer_type:
@@ -79,9 +82,11 @@ def split_dimension(kernel, split_iname, inner_length,
     applied_iname_rewrites = kernel.applied_iname_rewrites[:]
 
     if outer_iname is None:
-        outer_iname = split_iname+"_outer"
+        outer_iname = kernel.make_unique_var_name(
+                split_iname+"_outer")
     if inner_iname is None:
-        inner_iname = split_iname+"_inner"
+        inner_iname = kernel.make_unique_var_name(
+                split_iname+"_inner")
 
     def process_set(s):
         var_dict = s.get_var_dict()
@@ -598,4 +603,5 @@ def add_dependency(kernel, insn_match, dependency):
 
 
 
+
 # vim: foldmethod=marker
diff --git a/loopy/compiled.py b/loopy/compiled.py
index f4cdd67a4..c8b733d3d 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -4,6 +4,8 @@ import pyopencl.array as cl_array
 
 import numpy as np
 
+from pytools import Record
+
 
 
 
@@ -14,6 +16,7 @@ def _arg_matches_spec(arg, val, other_args):
     if isinstance(arg, lp.GlobalArg):
         from pymbolic import evaluate
         shape = evaluate(arg.shape, other_args)
+        strides = evaluate(arg.numpy_strides, other_args)
 
         if arg.dtype != val.dtype:
             raise TypeError("dtype mismatch on argument '%s' "
@@ -23,15 +26,10 @@ def _arg_matches_spec(arg, val, other_args):
             raise TypeError("shape mismatch on argument '%s' "
                     "(got: %s, expected: %s)"
                     % (arg.name, val.shape, shape))
-        if arg.order == "F" and not val.flags.f_contiguous:
-            raise TypeError("order mismatch on argument '%s' "
-                    "(expected Fortran-contiguous, but isn't)"
-                    % (arg.name))
-        if arg.order == "C" and not val.flags.c_contiguous:
-            print id(val), val.flags
-            raise TypeError("order mismatch on argument '%s' "
-                    "(expected C-contiguous, but isn't)"
-                    % (arg.name))
+        if strides != tuple(val.strides):
+            raise ValueError("strides mismatch on argument '%s' "
+                    "(got: %s, expected: %s)"
+                    % (arg.name, val.strides, strides))
 
     return True
 
@@ -152,7 +150,7 @@ class CompiledKernel:
 
             self.needs_check = False
 
-        domain_parameters = dict((name, kwargs[name])
+        domain_parameters = dict((name, int(kwargs[name]))
                 for name in self.kernel.scalar_loop_args)
 
         args = []
@@ -184,8 +182,15 @@ class CompiledKernel:
 
                 from pymbolic import evaluate
                 shape = evaluate(arg.shape, kwargs)
-                val = cl_array.empty(queue, shape, arg.dtype, order=arg.order,
-                        allocator=allocator)
+                numpy_strides = evaluate(arg.numpy_strides, kwargs)
+
+                from pytools import all
+                assert all(s > 0 for s in numpy_strides)
+                alloc_size = sum(astrd*(alen-1)
+                        for alen, astrd in zip(shape, numpy_strides)) + 1
+
+                storage = cl_array.empty(queue, alloc_size, arg.dtype)
+                val = cl_array.as_strided(storage, shape, numpy_strides)
             else:
                 assert _arg_matches_spec(arg, val, kwargs)
 
@@ -254,15 +259,18 @@ def fill_rand(ary):
 
 
 
+
+class TestArgInfo(Record):
+    """Test bookkeeping: ref_* set by make_ref_args, test_* by make_args."""
+
 def make_ref_args(kernel, queue, parameters,
         fill_value):
     from loopy.kernel import ValueArg, GlobalArg, ImageArg
 
     from pymbolic import evaluate
 
-    result = []
-    input_arrays = []
-    output_arrays = []
+    ref_args = {}
+    arg_descriptors = []
 
     for arg in kernel.args:
         if isinstance(arg, ValueArg):
@@ -276,7 +284,9 @@ def make_ref_args(kernel, queue, parameters,
             if argv_dtype != arg.dtype:
                 arg_value = arg.dtype.type(arg_value)
 
-            result.append(arg_value)
+            ref_args[arg.name] = arg_value
+
+            arg_descriptors.append(None)
 
         elif isinstance(arg, (GlobalArg, ImageArg)):
             if arg.shape is None:
@@ -284,51 +294,77 @@ def make_ref_args(kernel, queue, parameters,
                         "testing")
 
             shape = evaluate(arg.shape, parameters)
-            if isinstance(arg, ImageArg):
-                order = "C"
+
+            is_output = arg.name in kernel.get_written_variables()
+            is_image = isinstance(arg, ImageArg)
+
+            if is_image:
+                storage_array = ary = cl_array.empty(queue, shape, arg.dtype, order="C")
+                numpy_strides = None
+                alloc_size = None
+                strides = None
             else:
-                order = arg.order
                 assert arg.offset == 0
 
-            ary = cl_array.empty(queue, shape, arg.dtype, order=order)
-            if arg.name in kernel.get_written_variables():
-                if isinstance(arg, ImageArg):
+                strides = evaluate(arg.strides, parameters)
+
+                from pytools import all
+                assert all(s > 0 for s in strides)
+                alloc_size = sum(astrd*(alen-1)
+                        for alen, astrd in zip(shape, strides)) + 1
+
+                itemsize = arg.dtype.itemsize
+                numpy_strides = [itemsize*s for s in strides]
+
+                storage_array = cl_array.empty(queue, alloc_size, arg.dtype)
+                ary = cl_array.as_strided(storage_array, shape, numpy_strides)
+
+            if is_output:
+                if is_image:
                     raise RuntimeError("write-mode images not supported in "
                             "automatic testing")
 
                 if arg.dtype.isbuiltin:
-                    ary.fill(fill_value)
+                    storage_array.fill(fill_value)
                 else:
                     from warnings import warn
                     warn("Cannot pre-fill array of dtype '%s'" % arg.dtype)
 
-                output_arrays.append(ary)
-                result.append(ary.data)
+                ref_args[arg.name] = ary
             else:
-                fill_rand(ary)
-                input_arrays.append(ary)
+                fill_rand(storage_array)
                 if isinstance(arg, ImageArg):
-                    result.append(cl.image_from_array(queue.context, ary.get(), 1))
+                    # must be contiguous
+                    ref_args[arg.name] = cl.image_from_array(queue.context, ary.get(), 1)
                 else:
-                    result.append(ary.data)
-
+                    ref_args[arg.name] = ary
+
+            arg_descriptors.append(
+                    TestArgInfo(
+                        name=arg.name,
+                        ref_array=ary,
+                        ref_storage_array=storage_array,
+                        ref_shape=shape,
+                        ref_strides=strides,
+                        ref_alloc_size=alloc_size,
+                        ref_numpy_strides=numpy_strides,
+                        needs_checking=is_output))
         else:
             raise RuntimeError("arg type not understood")
 
-    return result, input_arrays, output_arrays
+    return ref_args, arg_descriptors
 
 
 
 
-def make_args(queue, kernel, ref_input_arrays, parameters,
+def make_args(queue, kernel, arg_descriptors, parameters,
         fill_value):
     from loopy.kernel import ValueArg, GlobalArg, ImageArg
 
     from pymbolic import evaluate
 
-    result = []
-    output_arrays = []
-    for arg in kernel.args:
+    args = {}
+    for arg, arg_desc in zip(kernel.args, arg_descriptors):
         if isinstance(arg, ValueArg):
             arg_value = parameters[arg.name]
 
@@ -340,39 +376,84 @@ def make_args(queue, kernel, ref_input_arrays, parameters,
             if argv_dtype != arg.dtype:
                 arg_value = arg.dtype.type(arg_value)
 
-            result.append(arg_value)
+            args[arg.name] = arg_value
 
-        elif isinstance(arg, (GlobalArg, ImageArg)):
+        elif isinstance(arg, ImageArg):
             if arg.name in kernel.get_written_variables():
-                if isinstance(arg, ImageArg):
-                    raise RuntimeError("write-mode images not supported in "
-                            "automatic testing")
+                raise NotImplementedError("write-mode images not supported in "
+                        "automatic testing")
 
-                shape = evaluate(arg.shape, parameters)
-                ary = cl_array.empty(queue, shape, arg.dtype, order=arg.order)
+            shape = evaluate(arg.shape, parameters)
+            assert shape == arg_desc.ref_shape
+
+            # must be contiguous
+            args[arg.name] = cl.image_from_array(
+                    queue.context, arg_desc.ref_array.get(), 1)
+
+        elif isinstance(arg, GlobalArg):
+            assert arg.offset == 0
+
+            shape = evaluate(arg.shape, parameters)
+            strides = evaluate(arg.strides, parameters)
+
+            itemsize = arg.dtype.itemsize
+            numpy_strides = [itemsize*s for s in strides]
+
+            assert all(s > 0 for s in strides)
+            alloc_size = sum(astrd*(alen-1)
+                    for alen, astrd in zip(shape, strides)) + 1
+
+            if arg.name in kernel.get_written_variables():
+                storage_array = cl_array.empty(queue, alloc_size, arg.dtype)
+                ary = cl_array.as_strided(storage_array, shape, numpy_strides)
 
                 if arg.dtype.isbuiltin:
-                    ary.fill(fill_value)
+                    storage_array.fill(fill_value)
                 else:
                     from warnings import warn
                     warn("Cannot pre-fill array of dtype '%s'" % arg.dtype)
 
-                assert arg.offset == 0
-                output_arrays.append(ary)
-                result.append(ary.data)
+                args[arg.name] = ary
             else:
-                ref_arg = ref_input_arrays.pop(0)
+                # use contiguous array to transfer to host
+                host_ref_contig_array = arg_desc.ref_storage_array.get()
 
-                if isinstance(arg, ImageArg):
-                    result.append(cl.image_from_array(queue.context, ref_arg.get(), 1))
-                else:
-                    ary = cl_array.to_device(queue, ref_arg.get())
-                    result.append(ary.data)
+                # use device shape/strides
+                from pyopencl.compyte.array import as_strided
+                host_ref_array = as_strided(host_ref_contig_array,
+                        arg_desc.ref_shape, arg_desc.ref_numpy_strides)
+
+                # flatten the thing
+                host_ref_flat_array = host_ref_array.flatten()
+
+                # create host array with test shape (but not strides)
+                host_contig_array = np.empty(shape, dtype=arg.dtype)
+
+                common_len = min(len(host_ref_flat_array), len(host_contig_array.ravel()))
+                host_contig_array.ravel()[:common_len] = host_ref_flat_array[:common_len]
+
+                # create host array with test shape and storage layout
+                host_storage_array = np.empty(alloc_size, arg.dtype)
+                host_array = as_strided(host_storage_array, shape, numpy_strides)
+                host_array[:] = host_contig_array
+
+                host_contig_array = arg_desc.ref_storage_array.get()
+                storage_array = cl_array.to_device(queue, host_storage_array)
+                ary = cl_array.as_strided(storage_array, shape, numpy_strides)
+
+                args[arg.name] = ary
+
+            arg_desc.test_storage_array = storage_array
+            arg_desc.test_array = ary
+            arg_desc.test_shape = shape
+            arg_desc.test_strides = strides
+            arg_desc.test_numpy_strides = numpy_strides
+            arg_desc.test_alloc_size = alloc_size
 
         else:
             raise RuntimeError("arg type not understood")
 
-    return result, output_arrays
+    return args
 
 
 
@@ -478,7 +559,7 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count=[], op_label=[], paramet
             print 75*"-"
 
         try:
-            ref_args, ref_input_arrays, ref_output_arrays = \
+            ref_args, arg_descriptors = \
                     make_ref_args(ref_sched_kernel, ref_queue, parameters,
                             fill_value=fill_value_ref)
         except cl.RuntimeError, e:
@@ -490,14 +571,7 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count=[], op_label=[], paramet
 
         print "using %s for the reference calculation" % dev
 
-        domain_parameters = dict((name, parameters[name])
-                for name in ref_knl.scalar_loop_args)
-
-        ref_evt = ref_compiled.cl_kernel(ref_queue,
-                ref_compiled.global_size_func(**domain_parameters),
-                ref_compiled.local_size_func(**domain_parameters),
-                *ref_args,
-                g_times_l=True)
+        ref_evt, _ = ref_compiled(ref_queue, **ref_args)
 
         ref_queue.finish()
         ref_stop = time()
@@ -518,7 +592,7 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count=[], op_label=[], paramet
     args = None
     for i, kernel in enumerate(kernel_gen):
         if args is None:
-            args, output_arrays = make_args(queue, kernel, ref_input_arrays, parameters,
+            args = make_args(queue, kernel, arg_descriptors, parameters,
                     fill_value=fill_value)
 
         compiled = CompiledKernel(ctx, kernel, edit_code=edit_code,
@@ -538,14 +612,30 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count=[], op_label=[], paramet
 
         do_check = True
 
-        gsize = compiled.global_size_func(**domain_parameters)
-        lsize = compiled.local_size_func(**domain_parameters)
         for i in range(warmup_rounds):
-            evt = compiled.cl_kernel(queue, gsize, lsize, *args, g_times_l=True)
+            evt, _ = compiled(queue, **args)
 
             if do_check:
-                for ref_out_ary, out_ary in zip(ref_output_arrays, output_arrays):
-                    error_is_small, error = check_result(out_ary.get(), ref_out_ary.get())
+                for arg_desc in arg_descriptors:
+                    if arg_desc is None:
+                        continue
+                    if not arg_desc.needs_checking:
+                        continue
+
+                    from pyopencl.compyte.array import as_strided
+                    ref_ary = as_strided(
+                            arg_desc.ref_storage_array.get(),
+                            shape=arg_desc.ref_shape,
+                            strides=arg_desc.ref_numpy_strides).flatten()
+                    test_ary = as_strided(
+                            arg_desc.test_storage_array.get(),
+                            shape=arg_desc.test_shape,
+                            strides=arg_desc.test_numpy_strides).flatten()
+                    common_len = min(len(ref_ary), len(test_ary))
+                    ref_ary = ref_ary[:common_len]
+                    test_ary = test_ary[:common_len]
+
+                    error_is_small, error = check_result(test_ary, ref_ary)
                     assert error_is_small, error
                     do_check = False
 
@@ -561,8 +651,8 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count=[], op_label=[], paramet
             evt_start = cl.enqueue_marker(queue)
 
             for i in range(timing_rounds):
-                events.append(
-                        compiled.cl_kernel(queue, gsize, lsize, *args, g_times_l=True))
+                evt, _ = compiled(queue, **args)
+                events.append(evt)
 
             evt_end = cl.enqueue_marker(queue)
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 30abc2561..ededdc530 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -170,8 +170,12 @@ class _ShapedArg(Record):
                 dtype=dtype,
                 strides=strides,
                 offset=offset,
-                shape=shape,
-                order=order)
+                shape=shape)
+
+    @property
+    @memoize_method
+    def numpy_strides(self):
+        return tuple(self.dtype.itemsize*s for s in self.strides)
 
     @property
     def dimensions(self):
@@ -1573,7 +1577,10 @@ class LoopKernel(Record):
         if exclude_instructions:
             new_insns = self.instructions
         else:
-            new_insns = [insn.copy(expression=func(insn.expression))
+            new_insns = [insn.copy(
+                expression=func(insn.expression),
+                assignee=func(insn.assignee),
+                )
                     for insn in self.instructions]
 
         return self.copy(
diff --git a/loopy/padding.py b/loopy/padding.py
new file mode 100644
index 000000000..c3b67a252
--- /dev/null
+++ b/loopy/padding.py
@@ -0,0 +1,213 @@
+from __future__ import division
+from loopy.symbolic import IdentityMapper
+
+
+
+
+class ArgAxisSplitHelper(IdentityMapper):
+    """Hand subscripts of the arrays named in *arg_names* to *handler*."""
+    def __init__(self, arg_names, handler):
+        self.arg_names, self.handler = arg_names, handler
+
+    def map_subscript(self, expr):
+        # Subscripts of any other aggregate get the default treatment.
+        if expr.aggregate.name not in self.arg_names:
+            return IdentityMapper.map_subscript(self, expr)
+        return self.handler(expr)
+
+
+
+
+
+def split_arg_axis(kernel, args_and_axes, count):
+    """Split one axis of each given array argument into an outer axis and
+    an inner axis of length *count*, and split the inames indexing those
+    axes to match. Returns the new kernel (*kernel* itself if *count* is 1).
+
+    :arg args_and_axes: a list of tuples *(arg, axis_nr)* indicating
+        that the index in *axis_nr* should be split. The tuples may
+        also be *(arg, axis_nr, "F")*, indicating that the index will
+        be split as it would according to Fortran order.
+
+        If *args_and_axes* is a :class:`tuple`, it is automatically
+        wrapped in a list, to make single splits easier.
+    """
+
+    if count == 1:
+        return kernel
+
+    def normalize_rest(rest):
+        if len(rest) == 1:
+            return (rest[0], "C")
+        elif len(rest) == 2:
+            return rest
+        else:
+            raise RuntimeError("split instruction '%s' not understood" % rest)
+
+    if isinstance(args_and_axes, tuple):
+        args_and_axes = [args_and_axes]
+
+    arg_to_rest = dict((tup[0], normalize_rest(tup[1:])) for tup in args_and_axes)
+
+    if len(args_and_axes) != len(arg_to_rest):
+        raise RuntimeError("cannot split multiple axes of the same variable")
+
+    from loopy.kernel import GlobalArg
+    for arg_name in arg_to_rest:
+        if not isinstance(kernel.arg_dict[arg_name], GlobalArg):
+            raise RuntimeError("only GlobalArg axes may be split")
+
+    arg_to_idx = dict((arg.name, i) for i, arg in enumerate(kernel.args))
+
+    # {{{ adjust args
+
+    new_args = kernel.args[:]
+    for arg_name, (axis, order) in arg_to_rest.iteritems():
+        arg_idx = arg_to_idx[arg_name]
+
+        arg = new_args[arg_idx]
+
+        from pytools import div_ceil
+
+        # {{{ adjust shape
+
+        new_shape = arg.shape
+        if new_shape is not None:
+            new_shape = list(new_shape)
+            axis_len = new_shape[axis]
+            new_shape[axis] = count
+            outer_len = div_ceil(axis_len, count)
+
+            if order == "F":
+                new_shape.insert(axis+1, outer_len)
+            elif order == "C":
+                new_shape.insert(axis, outer_len)
+            else:
+                raise RuntimeError("order '%s' not understood" % order)
+            new_shape = tuple(new_shape)
+
+        # }}}
+
+        # {{{ adjust strides
+
+        new_strides = list(arg.strides)
+        old_stride = new_strides[axis]
+        outer_stride = count*old_stride
+
+        if order == "F":
+            new_strides.insert(axis+1, outer_stride)
+        elif order == "C":
+            new_strides.insert(axis, outer_stride)
+        else:
+            raise RuntimeError("order '%s' not understood" % order)
+
+        new_strides = tuple(new_strides)
+
+        # }}}
+
+        new_args[arg_idx] = arg.copy(shape=new_shape, strides=new_strides)
+
+    # }}}
+
+    split_vars = {}
+
+    def split_access_axis(expr):
+        axis_nr, order = arg_to_rest[expr.aggregate.name]
+
+        idx = expr.index
+        if not isinstance(idx, tuple):
+            idx = (idx,)
+        idx = list(idx)
+
+        axis_idx = idx[axis_nr]
+        from pymbolic.primitives import Variable
+        if not isinstance(axis_idx, Variable):
+            raise RuntimeError("found access '%s' in which axis %d is not a "
+                    "single variable--cannot split" % (expr, axis_nr))
+
+        split_iname = axis_idx.name
+        assert split_iname in kernel.all_inames()
+
+        try:
+            outer_iname, inner_iname = split_vars[split_iname]
+        except KeyError:
+            outer_iname = kernel.make_unique_var_name(split_iname+"_outer")
+            inner_iname = kernel.make_unique_var_name(split_iname+"_inner")
+            split_vars[split_iname] = outer_iname, inner_iname
+
+        idx[axis_nr] = Variable(inner_iname)
+
+        if order == "F":
+            idx.insert(axis_nr+1, Variable(outer_iname))
+        elif order == "C":
+            idx.insert(axis_nr, Variable(outer_iname))
+        else:
+            raise RuntimeError("order '%s' not understood" % order)
+
+        return expr.aggregate[tuple(idx)]
+
+    # Match against every split argument, not just the last one the loop
+    # above happened to visit.
+    aash = ArgAxisSplitHelper(set(arg_to_rest), split_access_axis)
+
+    result = kernel.map_expressions(aash).copy(args=new_args)
+
+    from loopy import split_dimension
+
+    for split_iname, (outer_iname, inner_iname) in split_vars.iteritems():
+        result = split_dimension(result, split_iname, count,
+                outer_iname=outer_iname, inner_iname=inner_iname)
+
+    return result
+
+
+
+
+def find_padding_multiple(kernel, variable, axis, align_bytes, allowed_waste=0.1):
+    """Return the smallest positive integer *m* such that rounding *m* times
+    the stride of *axis* of *variable* up to a multiple of *align_bytes*
+    grows it by at most the fraction *allowed_waste*."""
+    stride = kernel.arg_dict[variable].strides[axis]
+    if not isinstance(stride, int):
+        raise RuntimeError("cannot find padding multi--stride is not a "
+                "known integer")
+
+    from pytools import div_ceil
+    def relative_waste(factor):
+        true_size = factor * stride
+        padded_size = div_ceil(true_size, align_bytes) * align_bytes
+        return (padded_size - true_size) / true_size
+
+    multiple = 1
+    while not (relative_waste(multiple) <= allowed_waste):
+        multiple += 1
+    return multiple
+
+
+
+
+def add_padding(kernel, variable, axis, align_bytes):
+    """Return a copy of *kernel* in which the stride of *axis* of argument
+    *variable* is rounded up to the next multiple of *align_bytes*."""
+    position = dict(
+            (arg.name, i) for i, arg in enumerate(kernel.args))[variable]
+    padded_args = kernel.args[:]
+    old_arg = padded_args[position]
+
+    strides = list(old_arg.strides)
+    if not isinstance(strides[axis], int):
+        raise RuntimeError("cannot find split granularity--stride is not a "
+                "known integer")
+
+    from pytools import div_ceil
+    strides[axis] = div_ceil(strides[axis], align_bytes) * align_bytes
+    padded_args[position] = old_arg.copy(strides=tuple(strides))
+
+    return kernel.copy(args=padded_args)
+
+
+
+
+
+
+# vim: foldmethod=marker
diff --git a/test/test_linalg.py b/test/test_linalg.py
index f008b5865..00b63ec6f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -94,7 +94,7 @@ def check_float4(result, ref_result):
 def test_axpy(ctx_factory):
     ctx = ctx_factory()
 
-    n = 20*1024**2
+    n = 3145182
 
     vec = cl_array.vec
 
@@ -644,7 +644,7 @@ def test_small_batched_matvec(ctx_factory):
 
     order = "C"
 
-    K = 10000
+    K = 9997
     Np = 36
 
     knl = lp.make_kernel(ctx.devices[0],
@@ -661,7 +661,11 @@ def test_small_batched_matvec(ctx_factory):
 
     seq_knl = knl
 
+    align_bytes = 64
     knl = lp.add_prefetch(knl, 'd[:,:]')
+    pad_mult = lp.find_padding_multiple(knl, "f", 0, align_bytes)
+    knl = lp.split_arg_axis(knl, ("f", 0), pad_mult)
+    knl = lp.add_padding(knl, "f", 0, align_bytes)
 
     kernel_gen = lp.generate_loop_schedules(knl)
     kernel_gen = lp.check_kernels(kernel_gen, dict(K=K))
-- 
GitLab