Skip to content
Snippets Groups Projects
Commit 73951d70 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Make a better implementation of duplicate_axes().

parent 5b5cacd1
No related branches found
No related tags found
No related merge requests found
......@@ -52,12 +52,6 @@ To-do
- Deal with equality constraints.
(These arise, e.g., when partitioning a loop of length 16 into 16s.)
- duplicate_dimensions can be implemented without having to muck around
with individual constraints:
- add_dims
- move_dims
- intersect
Future ideas
^^^^^^^^^^^^
......@@ -175,6 +169,11 @@ Dealt with
- Generalize reduction to be over multiple variables
- duplicate_dimensions can be implemented without having to muck around
with individual constraints:
- add_dims
- move_dims
- intersect
Should a dependency on an iname be forced in a CSE?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
......
......@@ -185,44 +185,44 @@ def static_value_of_pw_aff(pw_aff, constants_only):
def duplicate_axes(basic_set, duplicate_inames, new_inames):
old_to_new = dict((old_iname, new_iname)
for old_iname, new_iname in zip(duplicate_inames, new_inames))
def duplicate_axes(isl_obj, duplicate_inames, new_inames):
# {{{ add dims
start_idx = basic_set.dim(dim_type.set)
result = basic_set.insert_dims(
start_idx = isl_obj.dim(dim_type.set)
more_dims = isl_obj.insert_dims(
dim_type.set, start_idx,
len(duplicate_inames))
old_var_dict = basic_set.get_space().get_var_dict()
dup_iname_dims = [old_var_dict[iname] for iname in duplicate_inames]
for i, iname in enumerate(new_inames):
new_idx = start_idx+i
result = result.set_dim_name(
more_dims = more_dims.set_dim_name(
dim_type.set, new_idx, iname)
set_bs, = basic_set.get_basic_sets()
# }}}
for cns in set_bs.get_constraints():
if any(cns.involves_dims(*dim+(1,)) for dim in dup_iname_dims):
assert not cns.is_div_constraint()
if cns.is_equality():
new_cns = cns.equality_alloc(result.get_space())
else:
new_cns = cns.inequality_alloc(result.get_space())
iname_to_dim = more_dims.get_space().get_var_dict()
new_coeffs = {}
for key, val in cns.get_coefficients_by_name().iteritems():
if key in old_to_new:
new_coeffs[old_to_new[key]] = val
else:
new_coeffs[key] = val
moved_dims = isl_obj
new_cns = new_cns.set_coefficients_by_name(new_coeffs)
result = result.add_constraint(new_cns)
for old_iname, new_iname in zip(duplicate_inames, new_inames):
old_dt, old_idx = iname_to_dim[old_iname]
new_dt, new_idx = iname_to_dim[new_iname]
return result
moved_dims = moved_dims.set_dim_name(
old_dt, old_idx, new_iname)
moved_dims = (moved_dims
.move_dims(
dim_type.param, 0,
old_dt, old_idx, 1)
.move_dims(
new_dt, new_idx-1,
dim_type.param, 0, 1))
moved_dims = moved_dims.insert_dims(old_dt, old_idx, 1)
moved_dims = moved_dims.set_dim_name(
old_dt, old_idx, old_iname)
return moved_dims.intersect(more_dims)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment