diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 750330db17c6d90ffb46191a4c3dc955f1591208..cc1526af53728717fa0e65e11324038fa24ed7da 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -221,6 +221,10 @@ class CLArgumentInfo(Record):
     .. attribute:: base_name
     .. attribute:: dtype
     .. attribute:: shape
+    .. attribute:: strides
+
+        Strides in multiples of ``dtype.itemsize``.
+
     .. attribute:: offset_for_name
     """
 
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 08da01b98288b377735858282b6ec6a5013edb7b..56c388c69aabdcefdee6c2bd76520e75c419a92f 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -548,6 +548,27 @@ class ArrayBase(Record):
 
         return self.copy(**kwargs)
 
+    def vector_size(self):
+        """Return the size of the vector type used for the array
+        divided by the basic data type.
+
+        Note: For 3-vectors, this will be 4.
+        """
+
+        for i, dim_tag in enumerate(self.dim_tags):
+            if isinstance(dim_tag, VectorArrayDimTag):
+                shape_i = self.shape[i]
+                if not isinstance(shape_i, int):
+                    raise RuntimeError("shape of '%s' has non-constant "
+                            "integer axis %d (0-based)" % (
+                                self.name, i))
+
+                vec_dtype = cl.array.vec.types[self.dtype, shape_i]
+
+                return int(vec_dtype.itemsize) // int(self.dtype.itemsize)
+
+        return 1
+
     def decl_info(self, is_written, index_dtype):
         """Return a list of tuples ``(cgen_decl, arg_info)``, where
         *cgen_decl* is a :mod:`cgen` argument declarations, *arg_info*
@@ -556,7 +577,9 @@ class ArrayBase(Record):
 
         from loopy.codegen import CLArgumentInfo
 
-        def gen_decls(name_suffix, shape, dtype, user_index):
+        vector_size = self.vector_size()
+
+        def gen_decls(name_suffix, shape, strides, dtype, user_index):
             if dtype is None:
                 dtype = self.dtype
 
@@ -566,24 +589,28 @@ class ArrayBase(Record):
 
             if num_user_axes is None or user_axis >= num_user_axes:
                 # implemented by various argument types
+                full_name = self.name + name_suffix
+
                 yield (self.get_arg_decl(name_suffix, shape, dtype,
                     is_written),
                         CLArgumentInfo(
-                            name=self.name + name_suffix,
+                            name=full_name,
                             base_name=self.name,
                             dtype=dtype,
                             shape=shape,
+                            strides=strides,
                             offset_for_name=None))
 
                 if self.offset:
                     from cgen import Const, POD
-                    yield (Const(POD(index_dtype,
-                        self.name+name_suffix+"_offset")),
+                    offset_name = full_name+"_offset"
+                    yield (Const(POD(index_dtype, offset_name)),
                             CLArgumentInfo(
-                                name=self.name + name_suffix,
-                                base_name=self.name,
-                                dtype=dtype,
-                                shape=shape,
-                                offset_for_name=None))
+                                name=offset_name,
+                                base_name=None,
+                                dtype=index_dtype,
+                                shape=None,
+                                strides=None,
+                                offset_for_name=full_name))
 
                 return
 
@@ -595,8 +622,9 @@ class ArrayBase(Record):
                 else:
                     new_shape = shape + (self.shape[user_axis],)
 
-                for res in gen_decls(name_suffix, new_shape, dtype,
-                        user_index + (None,)):
+                for res in gen_decls(name_suffix, new_shape,
+                        strides + (dim_tag.stride // vector_size,),
+                        dtype, user_index + (None,)):
                     yield res
 
             elif isinstance(dim_tag, SeparateArrayArrayDimTag):
@@ -608,7 +636,7 @@ class ArrayBase(Record):
 
                 for i in xrange(shape_i):
                     for res in gen_decls(name_suffix + "_s%d" % i,
-                            shape + (self.shape[user_axis],), dtype,
+                            shape, strides, dtype,
                             user_index + (i,)):
                         yield res
 
@@ -619,7 +647,7 @@ class ArrayBase(Record):
                             "integer axis %d (0-based)" % (
                                 self.name, user_axis))
 
-                for res in gen_decls(name_suffix, shape,
+                for res in gen_decls(name_suffix, shape, strides,
                         cl.array.vec.types[dtype, shape_i],
                         user_index + (None,)):
                     yield res
@@ -628,7 +656,7 @@ class ArrayBase(Record):
                 raise RuntimeError("unsupported array dim implementation tag '%s' "
                         "in array '%s'" % (dim_tag, self.name))
 
-        for res in gen_decls("", (), self.dtype, ()):
+        for res in gen_decls("", (), (), self.dtype, ()):
             yield res
 
 # }}}
@@ -666,9 +694,18 @@ def get_access_info(ary, index, eval_expr):
     vector_index = None
     subscripts = [0] * num_target_axes
 
+    vector_size = ary.vector_size()
+
     for i, (idx, dim_tag) in enumerate(zip(index, ary.dim_tags)):
         if isinstance(dim_tag, FixedStrideArrayDimTag):
-            subscripts[dim_tag.target_axis] += dim_tag.stride*idx
+            if isinstance(dim_tag.stride, int):
+                if not dim_tag.stride % vector_size == 0:
+                    raise RuntimeError("stride of axis %d of array '%s' "
+                            "is not a multiple of the vector axis"
+                            % (i, ary.name))
+
+            subscripts[dim_tag.target_axis] += (dim_tag.stride // vector_size)*idx
+
         elif isinstance(dim_tag, SeparateArrayArrayDimTag):
             idx = eval_expr(idx)
             if not isinstance(idx, int):
@@ -676,6 +713,7 @@ def get_access_info(ary, index, eval_expr):
                         "index for separate-array axis %d (0-based)" % (
                             ary.name, index, i))
             array_suffix += "_s%d" % idx
+
         elif isinstance(dim_tag, VectorArrayDimTag):
             idx = eval_expr(idx)
 
@@ -685,6 +723,7 @@ def get_access_info(ary, index, eval_expr):
                             ary.name, index, i))
             assert vector_index is None
             vector_index = idx
+
         else:
             raise RuntimeError("unsupported array dim implementation tag '%s' "
                     "in array '%s'" % (dim_tag, ary.name))