Review notes (free-form commentary placed ahead of the first "diff --git"
header; nothing in this preamble is part of a hunk):

* loopy/precompute.py: a storage axis that is mapped onto a preexisting iname
  is tied to its freshly added primed copy before that copy is projected out.
  The in-code comment asks for an equality constraint, so the constraint is
  built with isl.Constraint.equality_from_aff.  The previous
  inequality_from_aff only enforced "new_var - saxis >= 0".
* The hunk line counts are unchanged by that fix.  The "index" line of
  loopy/precompute.py still names the post-image blob from before the fix.
* NOTE(review): "del new_iname_to_tag[saxis]" runs before the new
  "eliminated as having length 1" check, while this patch stops adding
  preexisting inames to new_iname_to_tag.  Confirm the del cannot raise
  KeyError before the intended LoopyError is reached.
* NOTE(review): test_precompute_some_exist binds ref_knl after all
  transformations, so auto_test_vs_ref compares the kernel with itself.
  Consider capturing ref_knl before the transformations.

diff --git a/loopy/precompute.py b/loopy/precompute.py
index 11b1396f15b4f9dc440ee75480a1d25fbc1e091a..ae973f98c1de87e2821575f3c65c03c989f696fc 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -1,4 +1,4 @@
-from __future__ import division, absolute_import
+from __future__ import division, absolute_import, print_function
 
 import six
 from six.moves import range, zip
@@ -25,6 +25,7 @@ THE SOFTWARE.
 """
 
 
+import islpy as isl
 from loopy.symbolic import (get_dependencies,
         RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
         SubstitutionRuleMappingContext)
@@ -450,21 +451,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     if precompute_inames is not None:
         preexisting_precompute_inames = (
                 set(precompute_inames) & kernel.all_inames())
-
-        if (
-                preexisting_precompute_inames
-                and
-                len(preexisting_precompute_inames) < len(precompute_inames)):
-            raise LoopyError("some (but not all) of the inames in "
-                    "precompute_inames already exist. existing: %s non-existing: %s"
-                    % (
-                        preexisting_precompute_inames,
-                        set(precompute_inames) - preexisting_precompute_inames))
-
-        precompute_inames_already_exist = bool(preexisting_precompute_inames)
-
     else:
-        precompute_inames_already_exist = False
+        preexisting_precompute_inames = set()
 
     # }}}
 
@@ -483,20 +471,22 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
             old_name = saxis
             name = "%s_%s" % (c_subst_name, old_name)
 
-        if precompute_inames is not None and i < len(precompute_inames):
+        if (precompute_inames is not None
+                and i < len(precompute_inames)
+                and precompute_inames[i]):
             name = precompute_inames[i]
             tag_lookup_saxis = name
-            if (not precompute_inames_already_exist
+            if (name not in preexisting_precompute_inames
                     and var_name_gen.is_name_conflicting(name)):
                 raise RuntimeError("new storage axis name '%s' "
                         "conflicts with existing name" % name)
-
-        if not precompute_inames_already_exist:
+        else:
             name = var_name_gen(name)
 
         storage_axis_names.append(name)
-        new_iname_to_tag[name] = storage_axis_to_tag.get(
-                tag_lookup_saxis, default_tag)
+        if name not in preexisting_precompute_inames:
+            new_iname_to_tag[name] = storage_axis_to_tag.get(
+                    tag_lookup_saxis, default_tag)
 
         prior_storage_axis_name_dict[name] = old_name
 
@@ -522,9 +512,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     if storage_axis_names:
         # {{{ find domain to be changed
 
-        change_inames = expanding_inames
-        if precompute_inames_already_exist:
-            change_inames = change_inames | preexisting_precompute_inames
+        change_inames = expanding_inames | preexisting_precompute_inames
 
         from loopy.kernel.tools import DomainChanger
         domch = DomainChanger(kernel, change_inames)
@@ -551,40 +539,105 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
             else:
                 del new_iname_to_tag[saxis]
 
-        if not precompute_inames_already_exist:
-            new_kernel_domains = domch.get_domains_with(
-                    abm.augment_domain_with_sweep(
-                        domch.domain, non1_storage_axis_names,
-                        boxify_sweep=fetch_bounding_box))
-        else:
-            check_domain = domch.domain
+                if saxis in preexisting_precompute_inames:
+                    raise LoopyError("precompute axis %d (1-based) was "
+                            "eliminated as "
+                            "having length 1 but also mapped to existing "
+                            "iname '%s'" % (i+1, saxis))
+
+        mod_domain = domch.domain
+
+        # {{{ modify the domain, taking into account preexisting inames
+
+        # inames may already exist in mod_domain, add them primed to start
+        primed_non1_saxis_names = [
+                iname+"'" for iname in non1_storage_axis_names]
 
-            # {{{ check the domain the preexisting inames' domain
+        mod_domain = abm.augment_domain_with_sweep(
+            domch.domain, primed_non1_saxis_names,
+            boxify_sweep=fetch_bounding_box)
 
-            # inames already exist in check_domain, add them primed
-            primed_non1_saxis_names = [
-                    iname+"'" for iname in non1_storage_axis_names]
+        check_domain = mod_domain
+
+        for i, saxis in enumerate(non1_storage_axis_names):
+            var_dict = mod_domain.get_var_dict(isl.dim_type.set)
+
+            if saxis in preexisting_precompute_inames:
+                # add equality constraint between existing and new variable
+
+                dt, dim_idx = var_dict[saxis]
+                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)
+
+                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
+                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)
+
+                mod_domain = mod_domain.add_constraint(
+                        isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))
+
+                # project out the new one
+                mod_domain = mod_domain.project_out(dt, dim_idx, 1)
+
+            else:
+                # remove the prime from the new variable
+                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
+                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)
 
-            check_domain = abm.augment_domain_with_sweep(
-                    check_domain, primed_non1_saxis_names,
-                    boxify_sweep=fetch_bounding_box)
+        # {{{ check that we got the desired domain
 
-            # project out the original copies
-            from loopy.isl_helpers import project_out
-            check_domain = project_out(check_domain, non1_storage_axis_names)
+        check_domain = check_domain.project_out_except(
+                primed_non1_saxis_names, [isl.dim_type.set])
 
-            for iname in non1_storage_axis_names:
-                var_dict = check_domain.get_var_dict()
-                dt, dim_idx = var_dict[iname+"'"]
-                check_domain = check_domain.set_dim_name(dt, dim_idx, iname)
+        mod_check_domain = mod_domain
 
-            if not (check_domain <= domch.domain and domch.domain <= check_domain):
-                raise LoopyError("domain of preexisting inames does not match "
-                        "domain needed for precompute")
+        # re-add the prime from the new variable
+        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)
 
-            # }}}
+        for saxis in non1_storage_axis_names:
+            dt, dim_idx = var_dict[saxis]
+            mod_check_domain = mod_check_domain.set_dim_name(dt, dim_idx, saxis+"'")
+
+        mod_check_domain = mod_check_domain.project_out_except(
+                primed_non1_saxis_names, [isl.dim_type.set])
+
+        mod_check_domain, check_domain = isl.align_two(
+                mod_check_domain, check_domain)
+
+        # The modified domain can't get bigger by adding constraints
+        assert mod_check_domain <= check_domain
+
+        if not check_domain <= mod_check_domain:
+            print(check_domain)
+            print(mod_check_domain)
+            raise LoopyError("domain of preexisting inames does not match "
+                    "domain needed for precompute")
+
+        # }}}
+
+        # {{{ check that we didn't shrink the original domain
+
+        # project out the new names from the modified domain
+        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
+        mod_check_domain = mod_domain.project_out_except(
+                orig_domain_inames, [isl.dim_type.set])
+
+        check_domain = domch.domain
+
+        mod_check_domain, check_domain = isl.align_two(
+                mod_check_domain, check_domain)
+
+        # The modified domain can't get bigger by adding constraints
+        assert mod_check_domain <= check_domain
+
+        if not check_domain <= mod_check_domain:
+            print(check_domain)
+            print(mod_check_domain)
+            raise LoopyError("original domain got shrunk by applying the precompute")
+
+        # }}}
+
+        # }}}
 
-            new_kernel_domains = domch.get_domains_with(domch.domain)
+        new_kernel_domains = domch.get_domains_with(mod_domain)
 
     else:
         # leave kernel domains unchanged
diff --git a/test/test_fortran.py b/test/test_fortran.py
index 4117b80a27b243dee1db94b5a0bb2b83b2ec8d49..c31c370076b681cb0593f38b6a4d92479541b872 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -453,6 +453,49 @@ def test_parse_and_fuse_two_kernels():
     knl, = lp.parse_transformed_fortran(fortran_src)
 
 
+def test_precompute_some_exist(ctx_factory):
+    fortran_src = """
+        subroutine dgemm(m,n,l,a,b,c)
+          implicit none
+          real*8 a(m,l),b(l,n),c(m,n)
+          integer m,n,k,i,j,l
+
+          do j = 1,n
+            do i = 1,m
+              do k = 1,l
+                c(i,j) = c(i,j) + b(k,j)*a(i,k)
+              end do
+            end do
+          end do
+        end subroutine
+        """
+
+    knl, = lp.parse_fortran(fortran_src)
+
+    assert len(knl.domains) == 1
+
+    knl = lp.split_iname(knl, "i", 8,
+            outer_tag="g.0", inner_tag="l.1")
+    knl = lp.split_iname(knl, "j", 8,
+            outer_tag="g.1", inner_tag="l.0")
+    knl = lp.split_iname(knl, "k", 8)
+    knl = lp.assume(knl, "n mod 8 = 0")
+    knl = lp.assume(knl, "m mod 8 = 0")
+    knl = lp.assume(knl, "l mod 8 = 0")
+
+    knl = lp.extract_subst(knl, "a_acc", "a[i1,i2]", parameters="i1, i2")
+    knl = lp.extract_subst(knl, "b_acc", "b[i1,i2]", parameters="i1, i2")
+    knl = lp.precompute(knl, "a_acc", "k_inner,i_inner",
+            precompute_inames="ktemp,itemp")
+    knl = lp.precompute(knl, "b_acc", "j_inner,k_inner",
+            precompute_inames="itemp,k2temp")
+
+    ref_knl = knl
+
+    ctx = ctx_factory()
+    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=128, m=128, l=128))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])