From 582139a8bbb0951cc8e6589b6cd35534c7b6df9c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 5 Jun 2013 10:46:53 -0400
Subject: [PATCH] Add argument unpacking for separate-array tagged arguments.

---
 loopy/compiled.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 58 insertions(+), 1 deletion(-)

diff --git a/loopy/compiled.py b/loopy/compiled.py
index c5a44b837..4543e96b0 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -31,10 +31,64 @@ import numpy as np
 from pytools import Record, memoize_method
 
 
+# {{{ object array argument unpacker
+
+class ArgumentUnpacker(object):
+    """For argument arrays with axes tagged to be implemented as separate
+    arrays, this class provides preprocessing of the passed arguments so that
+    all sub-arrays may be passed in one object array (under the original,
+    un-split argument name) and are unpacked into separate arrays before being
+    passed to the kernel.
+    """
+
+    def __init__(self, kernel):
+        # a list of items like (arg_name, [(index, unpacked_name), ...])
+        self.unpackable_args = []
+
+        from loopy.kernel.array import ArrayBase, SeparateArrayArrayDimTag
+        for arg in kernel.args:
+            if not isinstance(arg, ArrayBase):
+                continue
+
+            if arg.shape is None or arg.dim_tags is None:
+                continue
+
+            log_shape = []
+            for shape_i, dim_tag in zip(arg.shape, arg.dim_tags):
+                if isinstance(dim_tag, SeparateArrayArrayDimTag):
+                    log_shape.append(shape_i)
+
+            if not log_shape:
+                continue
+
+            from pytools import indices_in_shape
+            unpack_data = [
+                    (i, arg.name + "".join("_s%d" % sub_i for sub_i in i))
+                    for i in indices_in_shape(log_shape)]
+
+            self.unpackable_args.append(
+                    (arg.name, unpack_data))
+
+    def __call__(self, kernel_kwargs):
+        kernel_kwargs = kernel_kwargs.copy()
+
+        for arg_name, subscripts_and_names in self.unpackable_args:
+            if arg_name in kernel_kwargs:
+                arg = kernel_kwargs[arg_name]
+                for index, unpacked_name in subscripts_and_names:
+                    assert unpacked_name not in kernel_kwargs
+                    kernel_kwargs[unpacked_name] = arg[index]
+                del kernel_kwargs[arg_name]
+
+        return kernel_kwargs
+
+# }}}
+
+
 # {{{ domain parameter finder
 
 class DomainParameterFinder(object):
-    """Finds parameters from shapes of passed arguments."""
+    """Finds domain parameters from shapes of passed arguments."""
 
     def __init__(self, kernel, cl_arg_info):
         # a mapping from parameter names to a list of tuples
@@ -184,6 +238,8 @@ class CompiledKernel:
         self.codegen_kwargs = codegen_kwargs
         self.options = options
 
+        self.argument_unpacker = ArgumentUnpacker(kernel)
+
     @memoize_method
     def get_kernel_info(self, arg_to_dtype_set):
         kernel = self.kernel
@@ -334,6 +390,7 @@ class CompiledKernel:
 
         # }}}
 
+        kwargs = self.argument_unpacker(kwargs)
         kwargs.update(
                 kernel_info.domain_parameter_finder(kwargs))
 
-- 
GitLab