diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py index 8ac963835ec12702f2010806d1d49062422318a2..ad80475c1d27f67b3df8a885f60dd96ff28efe6a 100644 --- a/loopy/codegen/loop.py +++ b/loopy/codegen/loop.py @@ -386,29 +386,41 @@ def generate_sequential_loop_dim_code(codegen_state, sched_index): _, loop_iname_idx = dom_and_slab.get_var_dict()[loop_iname] + impl_domain = isl.align_spaces( + codegen_state.implemented_domain, + dom_and_slab, + obj_bigger_ok=True, + across_dim_types=True + ).params() + lbound = ( kernel.cache_manager.dim_min( dom_and_slab, loop_iname_idx) .gist(kernel.assumptions) - .gist(dom_and_slab.params()) + .gist(impl_domain) .coalesce()) ubound = ( kernel.cache_manager.dim_max( dom_and_slab, loop_iname_idx) .gist(kernel.assumptions) - .gist(dom_and_slab.params()) + .gist(impl_domain) .coalesce()) # }}} # {{{ find implemented loop, build inner code - from loopy.isl_helpers import make_loop_bounds_from_pwaffs + from loopy.symbolic import pw_aff_to_pw_aff_implemented_by_expr + impl_lbound = pw_aff_to_pw_aff_implemented_by_expr(lbound) + impl_ubound = pw_aff_to_pw_aff_implemented_by_expr(ubound) # impl_loop may be overapproximated + from loopy.isl_helpers import make_loop_bounds_from_pwaffs impl_loop = make_loop_bounds_from_pwaffs( dom_and_slab.space, - loop_iname, lbound, ubound) + loop_iname, + impl_lbound, + impl_ubound) for iname in moved_inames: dt, idx = impl_loop.get_var_dict()[iname] @@ -431,13 +443,9 @@ def generate_sequential_loop_dim_code(codegen_state, sched_index): astb = codegen_state.ast_builder - zero = isl.PwAff.zero_on_domain( - isl.LocalSpace.from_space( - lbound.get_space()).domain()) - from loopy.symbolic import pw_aff_to_expr - if (ubound - lbound).plain_is_equal(zero): + if ubound.is_equal(lbound): # single-trip, generate just a variable assignment, not a loop inner = merge_codegen_results(codegen_state, [ astb.emit_initializer( diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 
def pw_aff_to_pw_aff_implemented_by_expr(pw_aff):
    """Rebuild *pw_aff* so that its last piece acts as a catch-all.

    Every piece except the last keeps its own domain. The affine
    expression of the last piece is instead attached to whatever part of
    the parameter universe is left after removing the domains of all the
    other pieces. The guarded pieces and the catch-all are combined with
    ``union_max`` and the result is coalesced.

    NOTE(review): presumably this mirrors how the expression generated by
    ``pw_aff_to_expr`` treats the final piece as an unconditional "else"
    branch -- confirm against that function.

    :arg pw_aff: an object providing ``get_pieces()`` (a sequence of
        ``(set, aff)`` pairs) and ``space``; assumed to be an
        :class:`islpy.PwAff` defined over parameters only -- TODO confirm.
    :returns: the coalesced piecewise affine expression described above.
    :raises IndexError: if *pw_aff* has no pieces.
    """
    pieces = pw_aff.get_pieces()

    # Region of parameter space not claimed by any explicitly guarded piece.
    uncovered = isl.Set.universe(pw_aff.space.params())

    # The first piece is always treated as guarded, even when it is also the
    # last one; the slice then contributes the strictly interior pieces.
    guarded_pieces = [pieces[0]] + list(pieces[1:-1])

    result = None
    for guard, piece_aff in guarded_pieces:
        guarded = isl.PwAff.alloc(guard, piece_aff)
        if result is None:
            result = guarded
        else:
            result = result.union_max(guarded)
        uncovered = uncovered.intersect_params(guard.complement())

    # Only the affine part of the last piece is used; its own domain is
    # replaced by the leftover region computed above.
    fallback = isl.PwAff.alloc(uncovered, pieces[-1][1])
    return result.union_max(fallback).coalesce()