# {{{ simplify_pw_aff

def simplify_pw_aff(pw_aff, context=None):
    """Return *pw_aff* with redundant pieces merged.

    If *context* is given, *pw_aff* is first passed through
    ``gist_params(context)``.

    Afterwards, the pieces are searched repeatedly for an ordered pair
    ``(i, j)`` such that ``aff_i.gist(dom_j)`` is equal to ``aff_j``. When
    such a pair is found, both pieces are replaced by a single piece that
    carries ``aff_i`` on the union of the two domains, and the search starts
    over on the rebuilt object. The loop ends once no such pair exists.

    :arg pw_aff: the piecewise affine expression to simplify (used through
        ``gist_params``, ``get_pieces``, ``union_max``, ``eq_set`` and
        ``get_aggregate_domain``).
    :arg context: optional parameter set handed to ``gist_params``; skipped
        when *None*.
    :returns: the (possibly rebuilt) piecewise affine expression.
    """
    if context is not None:
        pw_aff = pw_aff.gist_params(context)

    # Reference point for the sanity check at the end. Note that this is
    # taken *after* the context has been gisted away.
    reference = pw_aff

    def find_mergeable_pair(piece_list):
        # First (keeper, absorbed) index pair, in row-major order, for which
        # the keeper's expression, gisted on the absorbed piece's domain, is
        # equal to the absorbed piece's expression. None if there is none.
        for keep_idx, (_, keep_aff) in enumerate(piece_list):
            for drop_idx, (drop_dom, drop_aff) in enumerate(piece_list):
                if (keep_idx != drop_idx
                        and keep_aff.gist(drop_dom).is_equal(drop_aff)):
                    return keep_idx, drop_idx

        return None

    while True:
        piece_list = pw_aff.get_pieces()

        found = find_mergeable_pair(piece_list)
        if found is None:
            # Nothing left to merge.
            break

        keep_idx, drop_idx = found
        keep_dom, keep_aff = piece_list[keep_idx]
        drop_dom = piece_list[drop_idx][0]

        # The keeper's expression covers the absorbed piece as well, so one
        # piece on the joint domain stands in for both of them.
        untouched = [
                piece
                for idx, piece in enumerate(piece_list)
                if idx != keep_idx and idx != drop_idx]

        pw_aff = isl.PwAff.alloc(keep_dom | drop_dom, keep_aff)
        for dom, aff in untouched:
            pw_aff = pw_aff.union_max(isl.PwAff.alloc(dom, aff))

    # Sanity check: wherever the result is defined, it must agree with the
    # expression we started from.
    assert pw_aff.get_aggregate_domain() <= pw_aff.eq_set(reference)

    return pw_aff

# }}}