diff --git a/loopy/compiled.py b/loopy/compiled.py
index 4543e96b0c3e1517c1a7044e10143711aaf134cf..f4b862366975f502e06ad59a5f14d520ea964eae 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -44,6 +44,7 @@ class ArgumentUnpacker(object):
     def __init__(self, kernel):
         # a list of items like (arg_name, [(index, unpacked_name), ...])
         self.unpackable_args = []
+        self.arg_name_to_base_arg_name = {}
 
         from loopy.kernel.array import ArrayBase, SeparateArrayArrayDimTag
         for arg in kernel.args:
@@ -56,6 +57,10 @@ class ArgumentUnpacker(object):
             log_shape = []
             for shape_i, dim_tag in zip(arg.shape, arg.dim_tags):
                 if isinstance(dim_tag, SeparateArrayArrayDimTag):
+                    if not isinstance(shape_i, int):
+                        raise TypeError("argument '%s' has non-integer "
+                                "separate-array axis" % arg.name)
+
                     log_shape.append(shape_i)
 
             if not log_shape:
@@ -69,6 +74,9 @@ class ArgumentUnpacker(object):
             self.unpackable_args.append(
                     (arg.name, unpack_data))
 
+            for index, sub_arg_name in unpack_data:
+                self.arg_name_to_base_arg_name[sub_arg_name] = arg.name
+
     def __call__(self, kernel_kwargs):
         kernel_kwargs = kernel_kwargs.copy()
 
@@ -248,7 +256,12 @@ class CompiledKernel:
         from loopy.kernel.tools import add_argument_dtypes
 
         if arg_to_dtype_set:
-            kernel = add_argument_dtypes(kernel, dict(arg_to_dtype_set))
+            arg_to_dtype = {}
+            for arg, dtype in arg_to_dtype_set:
+                arg_to_dtype[self.argument_unpacker
+                        .arg_name_to_base_arg_name.get(arg, arg)] = dtype
+
+            kernel = add_argument_dtypes(kernel, arg_to_dtype)
 
             from loopy.preprocess import infer_unknown_types
             kernel = infer_unknown_types(kernel, expect_completion=True)
@@ -365,13 +378,17 @@ class CompiledKernel:
         code_op = kwargs.pop("code_op", None)
         warn_numpy = kwargs.pop("warn_numpy", None)
 
+        kwargs = self.argument_unpacker(kwargs)
+
         # {{{ process arg types, get cl kernel
 
         import loopy as lp
 
         arg_to_dtype = {}
-        for arg in self.kernel.args:
-            val = kwargs.get(arg.name)
+        for arg_name, val in kwargs.iteritems():
+            arg_name = self.argument_unpacker \
+                    .arg_name_to_base_arg_name.get(arg_name, arg_name)
+            arg = self.kernel.arg_dict[arg_name]
 
             if arg.dtype is None and val is not None:
                 try:
@@ -390,7 +407,6 @@ class CompiledKernel:
 
         # }}}
 
-        kwargs = self.argument_unpacker(kwargs)
         kwargs.update(
                 kernel_info.domain_parameter_finder(kwargs))