Skip to content
Snippets Groups Projects
Commit 247e2ff8 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

More data tagging fixes. More tests passing.

parent ab7b11a3
No related branches found
No related tags found
No related merge requests found
......@@ -53,12 +53,8 @@ To-do
- rename IndexTag -> InameTag
- Data implementation tags
TODO initial bringup:
- Adapt padding
- Adapt automatic padding of temp variables
- turn base_indices into offset
TODO further:
- turn base_indices into offset
- vectorization
- automatic copies
- write_image()
......
......@@ -864,7 +864,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
name=target_var_name,
dtype=dtype,
base_indices=(0,)*len(non1_storage_shape),
shape=non1_storage_shape,
shape=tuple(non1_storage_shape),
is_local=None)
new_temporary_variables[target_var_name] = temp_var
......
......@@ -285,6 +285,7 @@ def _parse_shape_or_strides(x):
x = _pymbolic_parse_if_necessary(x)
if isinstance(x, lp.auto):
return x
assert not isinstance(x, list)
if not isinstance(x, tuple):
assert x is not lp.auto
x = (x,)
......
......@@ -262,8 +262,9 @@ class TemporaryVariable(ArrayBase):
# FIXME take into account storage_shape, or something like it
storage_shape = self.shape
for l in storage_shape:
temp_var_decl = ArrayOf(temp_var_decl, l)
if storage_shape:
temp_var_decl = ArrayOf(temp_var_decl,
" * ".join(str(s) for s in storage_shape))
if self.is_local:
temp_var_decl = CLLocal(temp_var_decl)
......
......@@ -108,24 +108,34 @@ def split_arg_axis(kernel, args_and_axes, count):
# }}}
# {{{ adjust strides
# {{{ adjust dim tags
new_strides = list(arg.strides)
old_stride = new_strides[axis]
if arg.dim_tags is None:
raise RuntimeError("dim_tags of '%s' are not known" % arg.name)
new_dim_tags = list(arg.dim_tags)
old_dim_tag = arg.dim_tags[axis]
from loopy.kernel.array import FixedStrideArrayDimTag
if not isinstance(old_dim_tag, FixedStrideArrayDimTag):
raise RuntimeError("axis %d of '%s' is not tagged fixed-stride"
% (axis, arg.name))
old_stride = old_dim_tag.stride
outer_stride = count*old_stride
if order == "F":
new_strides.insert(axis+1, outer_stride)
new_dim_tags.insert(axis+1, FixedStrideArrayDimTag(outer_stride))
elif order == "C":
new_strides.insert(axis, outer_stride)
new_dim_tags.insert(axis, FixedStrideArrayDimTag(outer_stride))
else:
raise RuntimeError("order '%s' not understood" % order)
new_strides = tuple(new_strides)
new_dim_tags = tuple(new_dim_tags)
# }}}
new_args[arg_idx] = arg.copy(shape=new_shape, strides=new_strides)
new_args[arg_idx] = arg.copy(shape=new_shape, dim_tags=new_dim_tags)
# }}}
......@@ -187,9 +197,22 @@ def split_arg_axis(kernel, args_and_axes, count):
def find_padding_multiple(kernel, variable, axis, align_bytes, allowed_waste=0.1):
arg = kernel.arg_dict[variable]
stride = arg.strides[axis]
if arg.dim_tags is None:
raise RuntimeError("cannot find padding multiple--dim_tags of '%s' "
"are not known" % variable)
dim_tag = arg.dim_tags[axis]
from loopy.kernel.array import FixedStrideArrayDimTag
if not isinstance(dim_tag, FixedStrideArrayDimTag):
raise RuntimeError("cannot find padding multiple--"
"axis %d of '%s' is not tagged fixed-stride"
% (axis, variable))
stride = dim_tag.stride
if not isinstance(stride, int):
raise RuntimeError("cannot find padding multi--stride is not a "
raise RuntimeError("cannot find padding multiple--stride is not a "
"known integer")
from pytools import div_ceil
......@@ -212,21 +235,31 @@ def add_padding(kernel, variable, axis, align_bytes):
new_args = kernel.args[:]
arg = new_args[arg_idx]
new_strides = list(arg.strides)
stride = new_strides[axis]
if arg.dim_tags is None:
raise RuntimeError("cannot add padding--dim_tags of '%s' "
"are not known" % variable)
new_dim_tags = list(arg.dim_tags)
dim_tag = new_dim_tags[axis]
from loopy.kernel.array import FixedStrideArrayDimTag
if not isinstance(dim_tag, FixedStrideArrayDimTag):
raise RuntimeError("cannot find padding multiple--"
"axis %d of '%s' is not tagged fixed-stride"
% (axis, variable))
stride = dim_tag.stride
if not isinstance(stride, int):
raise RuntimeError("cannot find split granularity--stride is not a "
"known integer")
from pytools import div_ceil
new_strides[axis] = div_ceil(stride, align_bytes) * align_bytes
new_dim_tags[axis] = FixedStrideArrayDimTag(
div_ceil(stride, align_bytes) * align_bytes)
new_args[arg_idx] = arg.copy(strides=tuple(new_strides))
new_args[arg_idx] = arg.copy(dim_tags=tuple(new_dim_tags))
return kernel.copy(args=new_args)
# vim: foldmethod=marker
......@@ -710,14 +710,24 @@ def get_auto_axis_iname_ranking_by_stride(kernel, insn):
ary_name = aae.aggregate.name
arg = kernel.arg_dict.get(ary_name)
ary_strides = arg.strides
if ary_strides is None and len(index_expr) == 1:
ary_strides = (1,)
if arg.dim_tags is None:
from warnings import warn
warn("Strides for '%s' are not known. Local axis assignment "
"is likely suboptimal." % arg.name)
ary_strides = [1] * len(index_expr)
else:
ary_strides = []
from loopy.kernel.array import FixedStrideArrayDimTag
for dim_tag in arg.dim_tags:
if isinstance(dim_tag, FixedStrideArrayDimTag):
ary_strides.append(dim_tag.stride)
# {{{ construct iname_to_stride_expr
iname_to_stride_expr = {}
for iexpr_i, stride in zip(index_expr, ary_strides):
if stride is None:
continue
coeffs = CoefficientCollector()(iexpr_i)
for var_name, coeff in coeffs.iteritems():
if var_name in auto_axis_inames: # excludes '1', i.e. the constant
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment