diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 0ebe90fbca0d31c05eaee64321e2b73709292331..36fbb49f4bb77c959877fb0bd21e1de6fb49c74b 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -594,6 +594,10 @@ def get_simple_strides(bset, key_by="name"): """ result = {} + comp_div_set_pieces = convexify(bset.compute_divs()).get_basic_sets() + assert len(comp_div_set_pieces) == 1 + bset, = comp_div_set_pieces + lspace = bset.get_local_space() for idiv in range(lspace.dim(dim_type.div)): div = lspace.get_div(idiv) diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index a19e06ecdf7c9966501ebb9600ea4e01614363f4..5b208d0a43ef601411fa20bf7d23c942c686210e 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -681,12 +681,18 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, dt, dim_idx = var_dict[primed_non1_saxis_names[i]] mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis) + def add_assumptions(d): + assumption_non_param = isl.BasicSet.from_params(kernel.assumptions) + assumptions, domain = isl.align_two(assumption_non_param, d) + return d & assumptions + # {{{ check that we got the desired domain - check_domain = check_domain.project_out_except( - primed_non1_saxis_names, [isl.dim_type.set]) + check_domain = add_assumptions( + check_domain.project_out_except( + primed_non1_saxis_names, [isl.dim_type.set])) - mod_check_domain = mod_domain + mod_check_domain = add_assumptions(mod_domain) # re-add the prime from the new variable var_dict = mod_check_domain.get_var_dict(isl.dim_type.set) @@ -716,10 +722,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # project out the new names from the modified domain orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set)) - mod_check_domain = mod_domain.project_out_except( - orig_domain_inames, [isl.dim_type.set]) + mod_check_domain = add_assumptions( + mod_domain.project_out_except( + orig_domain_inames, [isl.dim_type.set])) - check_domain = domch.domain + check_domain = add_assumptions(domch.domain) mod_check_domain, check_domain = isl.align_two( mod_check_domain, check_domain)