def project_out(set, inames):
    """Return *set* with every dimension named in *inames* projected out.

    :arg set: an object providing ``get_var_dict()`` (mapping a dimension
        name to a ``(dim_type, index)`` pair) and
        ``project_out(dim_type, index, count)``.
    :arg inames: an iterable of dimension names to remove, one at a time.
    :returns: the result of the successive ``project_out`` calls; *set*
        itself if *inames* is empty.

    A name missing from the variable dictionary propagates whatever lookup
    error that dictionary raises.
    """
    remaining = set

    for name in inames:
        # The lookup is redone on every pass, presumably because removing
        # one dimension can shift the indices of the ones that remain.
        dim_type, position = remaining.get_var_dict()[name]
        remaining = remaining.project_out(dim_type, position, 1)

    return remaining
def test_precompute_with_preexisting_inames(ctx_factory):
    """Precompute two substitution rules onto the same inames ``ii,jj``.

    Both sweeps (``i,j`` and ``i,k``) range over ``0 <= . < n``, so the
    second precompute is expected to accept the already-existing inames.
    The transformed kernel is then compared against the untransformed one.
    """
    ctx = ctx_factory()

    kernel = lp.make_kernel(
        "{[e,i,j,k]: 0<=e<E and 0<=i,j,k<n}",
        """
        result[e,i] = sum(j, D1[i,j]*u[e,j])
        result2[e,i] = sum(k, D2[i,k]*u[e,k])
        """)

    dtypes = dict.fromkeys(("u", "D1", "D2"), np.float32)
    kernel = lp.add_and_infer_dtypes(kernel, dtypes)
    kernel = lp.fix_parameters(kernel, n=13)

    # Snapshot taken before any substitution/precompute transformation.
    reference = kernel

    for array_name in ("D1", "D2"):
        kernel = lp.extract_subst(
            kernel, array_name + "_subst", array_name + "[ii,jj]",
            parameters="ii,jj")

    for subst_name, sweep in (("D1_subst", "i,j"), ("D2_subst", "i,k")):
        kernel = lp.precompute(
            kernel, subst_name, sweep, default_tag="for",
            precompute_inames="ii,jj")

    kernel = lp.set_loop_priority(kernel, "ii,jj,e,j,k")

    lp.auto_test_vs_ref(reference, ctx, kernel, parameters={"E": 200})


def test_precompute_with_preexisting_inames_fail():
    """Reusing ``ii,jj`` must fail when the sweep domains disagree.

    Here ``k`` ranges over ``0 <= k < 2*n`` while ``j`` ranges over
    ``0 <= j < n``, so the second precompute is expected to raise
    ``lp.LoopyError``.
    """
    kernel = lp.make_kernel(
        "{[e,i,j,k]: 0<=e<E and 0<=i,j<n and 0<=k<2*n}",
        """
        result[e,i] = sum(j, D1[i,j]*u[e,j])
        result2[e,i] = sum(k, D2[i,k]*u[e,k])
        """)

    dtypes = dict.fromkeys(("u", "D1", "D2"), np.float32)
    kernel = lp.add_and_infer_dtypes(kernel, dtypes)
    kernel = lp.fix_parameters(kernel, n=13)

    for array_name in ("D1", "D2"):
        kernel = lp.extract_subst(
            kernel, array_name + "_subst", array_name + "[ii,jj]",
            parameters="ii,jj")

    # The first precompute creates ii,jj; only the second one should raise.
    kernel = lp.precompute(
        kernel, "D1_subst", "i,j", default_tag="for",
        precompute_inames="ii,jj")

    with pytest.raises(lp.LoopyError):
        lp.precompute(
            kernel, "D2_subst", "i,k", default_tag="for",
            precompute_inames="ii,jj")