diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py index f27458fec99ddabc97df1fa35fcf384000f8c73f..dcea94690c2cca630b40a870f777d884715f1dd4 100644 --- a/loopy/codegen/loop.py +++ b/loopy/codegen/loop.py @@ -320,11 +320,11 @@ def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state): static_lbound = static_min_of_pw_aff( kernel.cache_manager.dim_min( dom_and_slab, loop_iname_idx).coalesce(), - constants_only=False) + constants_only=False, prefer_constants=False) static_ubound = static_max_of_pw_aff( kernel.cache_manager.dim_max( dom_and_slab, loop_iname_idx).coalesce(), - constants_only=False) + constants_only=False, prefer_constants=False) # }}} diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index fdf38b768269dfc8e86c76de9fd1dcf05e4e86e6..53e6309c7cf8eb787f7e6d6d47f65c9da716aad9 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -137,7 +137,8 @@ def iname_rel_aff(space, iname, rel, aff): raise ValueError("unknown value of 'rel': %s" % rel) -def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what, context): +def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what, context, + prefer_constants): if context is not None: context = isl.align_spaces(context, pw_aff.get_domain_space(), obj_bigger_ok=True).params() @@ -163,12 +164,21 @@ def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what, context) # put constant bounds with unbounded validity first # FIXME: Heuristi-hack. - order = [ - (True, False), # constant, unbounded validity - (False, False), # nonconstant, unbounded validity - (True, True), # constant, bounded validity - (False, True), # nonconstant, bounded validity - ] + if prefer_constants: + order = [ + (True, False), # constant, unbounded validity + (False, False), # nonconstant, unbounded validity + (True, True), # constant, bounded validity + (False, True), # nonconstant, bounded validity + ] + else: + order = [ + (False, False), # nonconstant, unbounded validity + (True, False), # constant, unbounded validity + (False, True), # nonconstant, bounded validity + (True, True), # constant, bounded validity + ] + pieces = flatten([ [(set, aff) for set, aff in pieces if aff.is_cst() == want_is_constant @@ -199,19 +209,22 @@ def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what, context) % (what, pw_aff)) -def static_min_of_pw_aff(pw_aff, constants_only, context=None): +def static_min_of_pw_aff(pw_aff, constants_only, context=None, + prefer_constants=True): return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.ge_set, - "minimum", context) + "minimum", context, prefer_constants) -def static_max_of_pw_aff(pw_aff, constants_only, context=None): +def static_max_of_pw_aff(pw_aff, constants_only, context=None, + prefer_constants=True): return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.le_set, - "maximum", context) + "maximum", context, prefer_constants) -def static_value_of_pw_aff(pw_aff, constants_only, context=None): +def static_value_of_pw_aff(pw_aff, constants_only, context=None, + prefer_constants=True): return static_extremum_of_pw_aff(pw_aff, constants_only, isl.PwAff.eq_set, - "value", context) + "value", context, prefer_constants) def duplicate_axes(isl_obj, duplicate_inames, new_inames):