Skip to content
Snippets Groups Projects
Commit 45657b41 authored by Tim Warburton's avatar Tim Warburton
Browse files

Factor out loopy.isl_helpers.duplicate_axes from CSE realization.

parent 63893d99
No related branches found
No related tags found
No related merge requests found
......@@ -358,48 +358,9 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non
# {{{ build new domain, duplicating each constraint on duplicated inames
dup_iname_dims = [kernel.iname_to_dim[iname]
for iname in duplicate_inames]
old_to_new = dict((old_iname, new_iname)
for old_iname, new_iname in zip(duplicate_inames, new_inames))
def realize_duplication(set):
start_idx = set.dim(dim_type.set)
result = set.insert_dims(
dim_type.set, start_idx,
len(duplicate_inames))
new_iname_to_dim = kernel.iname_to_dim.copy()
for i, iname in enumerate(new_inames):
new_idx = start_idx+i
result = result.set_dim_name(
dim_type.set, new_idx, iname)
new_iname_to_dim[iname] = (dim_type.set, new_idx)
set_bs, = set.get_basic_sets()
for cns in set_bs.get_constraints():
if any(cns.involves_dims(*dim+(1,)) for dim in dup_iname_dims):
assert not cns.is_div_constraint()
if cns.is_equality():
new_cns = cns.equality_alloc(result.get_space())
else:
new_cns = cns.inequality_alloc(result.get_space())
new_coeffs = {}
for key, val in cns.get_coefficients_by_name().iteritems():
if key in old_to_new:
new_coeffs[old_to_new[key]] = val
else:
new_coeffs[key] = val
new_cns = new_cns.set_coefficients_by_name(new_coeffs)
result = result.add_constraint(new_cns)
return result, new_iname_to_dim
new_domain, new_iname_to_dim = realize_duplication(kernel.domain)
from loopy.isl_helpers import duplicate_axes
new_domain = duplicate_axes(kernel.domain, duplicate_inames, new_inames)
new_iname_to_dim = new_domain.get_space().get_var_dict()
# }}}
......
......@@ -185,4 +185,46 @@ def static_value_of_pw_aff(pw_aff, constants_only):
def duplicate_axes(basic_set, duplicate_inames, new_inames):
old_to_new = dict((old_iname, new_iname)
for old_iname, new_iname in zip(duplicate_inames, new_inames))
start_idx = basic_set.dim(dim_type.set)
result = basic_set.insert_dims(
dim_type.set, start_idx,
len(duplicate_inames))
old_var_dict = basic_set.get_space().get_var_dict()
dup_iname_dims = [old_var_dict[iname] for iname in duplicate_inames]
for i, iname in enumerate(new_inames):
new_idx = start_idx+i
result = result.set_dim_name(
dim_type.set, new_idx, iname)
set_bs, = basic_set.get_basic_sets()
for cns in set_bs.get_constraints():
if any(cns.involves_dims(*dim+(1,)) for dim in dup_iname_dims):
assert not cns.is_div_constraint()
if cns.is_equality():
new_cns = cns.equality_alloc(result.get_space())
else:
new_cns = cns.inequality_alloc(result.get_space())
new_coeffs = {}
for key, val in cns.get_coefficients_by_name().iteritems():
if key in old_to_new:
new_coeffs[old_to_new[key]] = val
else:
new_coeffs[key] = val
new_cns = new_cns.set_coefficients_by_name(new_coeffs)
result = result.add_constraint(new_cns)
return result
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment