diff --git a/loopy/loop.py b/loopy/loop.py index 66d413987466e98e5a188df93cad49f5584cd3f7..a2793c196991f24aff7e007ec8cb930e171dfa74 100644 --- a/loopy/loop.py +++ b/loopy/loop.py @@ -32,7 +32,8 @@ def potential_loop_nest_map(kernel): """Returns a dictionary mapping inames to other inames that *could* be nested around them. - :seealso: :func:`loopy.schedule.loop_nest_map` + * :seealso: :func:`loopy.schedule.loop_nest_map` + * :seealso: :func:`loopy.schedule.find_loop_nest_around_map` """ result = {} @@ -65,6 +66,8 @@ def fuse_loop_domains(kernel): parents_per_domain = kernel.parents_per_domain() all_parents_per_domain = kernel.all_parents_per_domain() + iname_to_insns = kernel.iname_to_insns() + new_domains = None for inner_iname, outer_inames in six.iteritems(lnm): @@ -77,6 +80,12 @@ def fuse_loop_domains(kernel): if inner_domain_idx == outer_domain_idx: break + if iname_to_insns[inner_iname] != iname_to_insns[outer_iname]: + # The two inames are imperfectly nested. Domain fusion + # might be invalid when the inner loop is empty, leading to + # the outer loop also being empty. + continue + if ( outer_domain_idx in all_parents_per_domain[inner_domain_idx] and not diff --git a/test/test_fortran.py b/test/test_fortran.py index 496b470dea2131890e1c4d113226875a8f8b74a9..902c2d1b7aa0b9436ed51e01dc687837edf8f4ba 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -517,12 +517,32 @@ def test_fortran_subroutines(ctx_factory): call twice(n, a(1:n, i)) call twice(n, a(i, 1:n)) + end subroutine + """ + prg = lp.parse_fortran(fortran_src) + print(lp.generate_code_v2(prg).device_code()) +def test_domain_fusion_imperfectly_nested(): + fortran_src = """ + subroutine imperfect(n, m, a, b) + implicit none + integer i, j, n, m + real a(n), b(n,n) + + do i=1, n + a(i) = i + do j=1, m + b(i,j) = i*j + end do + end do end subroutine """ + prg = lp.parse_fortran(fortran_src) - print(lp.generate_code_v2(prg).device_code()) + # If n > 0 and m == 0, a single domain would be empty, + # leading (incorrectly) to no assignments to 'a'. + assert len(prg["imperfect"].subkernel.domains) > 1 if __name__ == "__main__":