diff --git a/loopy/__init__.py b/loopy/__init__.py index 5e36fc11159b5c0483995f6b5432f4bbb370fecf..aa2947e47dc3443dc861942def3b0bc096445a5f 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -358,48 +358,9 @@ def realize_cse(kernel, cse_tag, dtype, duplicate_inames=[], parallel_inames=Non # {{{ build new domain, duplicating each constraint on duplicated inames - dup_iname_dims = [kernel.iname_to_dim[iname] - for iname in duplicate_inames] - old_to_new = dict((old_iname, new_iname) - for old_iname, new_iname in zip(duplicate_inames, new_inames)) - - def realize_duplication(set): - start_idx = set.dim(dim_type.set) - result = set.insert_dims( - dim_type.set, start_idx, - len(duplicate_inames)) - - new_iname_to_dim = kernel.iname_to_dim.copy() - for i, iname in enumerate(new_inames): - new_idx = start_idx+i - result = result.set_dim_name( - dim_type.set, new_idx, iname) - new_iname_to_dim[iname] = (dim_type.set, new_idx) - - - set_bs, = set.get_basic_sets() - - for cns in set_bs.get_constraints(): - if any(cns.involves_dims(*dim+(1,)) for dim in dup_iname_dims): - assert not cns.is_div_constraint() - if cns.is_equality(): - new_cns = cns.equality_alloc(result.get_space()) - else: - new_cns = cns.inequality_alloc(result.get_space()) - - new_coeffs = {} - for key, val in cns.get_coefficients_by_name().iteritems(): - if key in old_to_new: - new_coeffs[old_to_new[key]] = val - else: - new_coeffs[key] = val - - new_cns = new_cns.set_coefficients_by_name(new_coeffs) - result = result.add_constraint(new_cns) - - return result, new_iname_to_dim - - new_domain, new_iname_to_dim = realize_duplication(kernel.domain) + from loopy.isl_helpers import duplicate_axes + new_domain = duplicate_axes(kernel.domain, duplicate_inames, new_inames) + new_iname_to_dim = new_domain.get_space().get_var_dict() # }}} diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index de0c1f0f12fad00b678da76e7f55400b85df4144..c10082086f76d06f50f53e384162b85c954bdc7a 100644 --- 
def duplicate_axes(basic_set, duplicate_inames, new_inames):
    """Return *basic_set* extended by duplicates of some of its set dimensions.

    One new set dimension is appended (after all existing set dimensions)
    for every entry of *duplicate_inames* and is given the name found at the
    same position in *new_inames*. Every constraint of *basic_set* that
    involves at least one of the duplicated dimensions is then added a
    second time, with the duplicated names replaced by their new names, so
    that each new dimension is bounded like the one it was copied from.

    :arg basic_set: the isl set to extend. It must consist of exactly one
        basic set; otherwise unpacking ``get_basic_sets()`` raises
        :exc:`ValueError`.
    :arg duplicate_inames: names of the existing set dimensions to duplicate.
    :arg new_inames: names for the duplicates, matched by position with
        *duplicate_inames*.
    :returns: the extended set; *basic_set* itself is not rebound or
        returned.
    :raises ValueError: if *duplicate_inames* and *new_inames* differ in
        length.
    :raises KeyError: if an entry of *duplicate_inames* is not a variable of
        the space of *basic_set*.
    """
    duplicate_inames = list(duplicate_inames)
    new_inames = list(new_inames)

    # zip() would silently truncate to the shorter sequence, while the number
    # of inserted dimensions and the number of assigned names would come from
    # different sequences. Reject the mismatch up front instead.
    if len(duplicate_inames) != len(new_inames):
        raise ValueError(
                "duplicate_inames and new_inames must have the same length "
                "(got %d and %d)"
                % (len(duplicate_inames), len(new_inames)))

    old_to_new = dict(zip(duplicate_inames, new_inames))

    # Look the duplicated dimensions up in the *original* space, as
    # (dim_type, index) pairs.
    old_var_dict = basic_set.get_space().get_var_dict()
    dup_iname_dims = [old_var_dict[iname] for iname in duplicate_inames]

    # Append the new (initially unconstrained) dimensions and name them.
    start_idx = basic_set.dim(dim_type.set)
    result = basic_set.insert_dims(
            dim_type.set, start_idx, len(duplicate_inames))
    for offset, iname in enumerate(new_inames):
        result = result.set_dim_name(
                dim_type.set, start_idx + offset, iname)

    # Single-element unpacking: fails unless there is exactly one basic set.
    set_bs, = basic_set.get_basic_sets()

    for cns in set_bs.get_constraints():
        involves_duplicated = any(
                cns.involves_dims(dim_tp, dim_idx, 1)
                for dim_tp, dim_idx in dup_iname_dims)
        if not involves_duplicated:
            continue

        # Renaming coefficients is only implemented for plain constraints.
        assert not cns.is_div_constraint()

        if cns.is_equality():
            new_cns = cns.equality_alloc(result.get_space())
        else:
            new_cns = cns.inequality_alloc(result.get_space())

        # Redirect coefficients of duplicated inames to their new names;
        # every other key is carried over unchanged.
        # .items() (rather than .iteritems()) works on Python 2 and 3 alike.
        new_coeffs = {}
        for key, val in cns.get_coefficients_by_name().items():
            new_coeffs[old_to_new.get(key, key)] = val

        new_cns = new_cns.set_coefficients_by_name(new_coeffs)
        result = result.add_constraint(new_cns)

    return result


# vim: foldmethod=marker