diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py index 1b4e057589360d2c663e7cc0fdddf40f093f2c06..3e18e997c517749a61eb6fc616806e07d682137c 100644 --- a/loopy/codegen/loop.py +++ b/loopy/codegen/loop.py @@ -129,19 +129,15 @@ def generate_unroll_loop(kernel, sched_index, codegen_state): tag = kernel.iname_to_tag.get(iname) bounds = kernel.get_iname_bounds(iname) - from loopy.isl_helpers import static_max_of_pw_aff + from loopy.isl_helpers import ( + static_max_of_pw_aff, static_value_of_pw_aff) from loopy.symbolic import pw_aff_to_expr length = int(pw_aff_to_expr( static_max_of_pw_aff(bounds.size, constants_only=True))) - lower_bound_pw_aff_pieces = bounds.lower_bound_pw_aff.coalesce().get_pieces() - - if len(lower_bound_pw_aff_pieces) > 1: - raise NotImplementedError("lower bound for unroll of '%s'" - "needs conditional/has more than one piece:\n%s" % ( - iname, "\n".join(str(piece) for piece in lower_bound_pw_aff_pieces))) - - (_, lower_bound_aff), = lower_bound_pw_aff_pieces + lower_bound_aff = static_value_of_pw_aff( + bounds.lower_bound_pw_aff.coalesce(), + constants_only=False) from loopy.kernel import UnrollTag if isinstance(tag, UnrollTag): diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index d1d42ffcfcdbd03e86389c39c2313c8684ff61d8..84cebabcbf346a499d31602807af65872c6c5aa2 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -156,6 +156,7 @@ def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what): return pieces[0][1] agg_domain = pw_aff.get_aggregate_domain() + for set, candidate_aff in pieces: if constants_only and not candidate_aff.is_cst(): continue @@ -177,6 +178,10 @@ def static_max_of_pw_aff(pw_aff, constants_only): return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.le_set, "maximum") +def static_value_of_pw_aff(pw_aff, constants_only): + return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.eq_set, + "value") +