From 55dbbd6d8dc7c554aea8396a74319e529302553e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 13 Mar 2017 19:22:25 -0500 Subject: [PATCH] Add simplify_pw_aff to generate simpler loop bounds --- loopy/codegen/loop.py | 7 ++++++- loopy/isl_helpers.py | 49 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py index a33446204..0110a0609 100644 --- a/loopy/codegen/loop.py +++ b/loopy/codegen/loop.py @@ -465,12 +465,17 @@ def generate_sequential_loop_dim_code(codegen_state, sched_index): else: inner_ast = inner.current_ast(codegen_state) + + from loopy.isl_helpers import simplify_pw_aff + result.append( inner.with_new_ast( codegen_state, astb.emit_sequential_loop( codegen_state, loop_iname, kernel.index_dtype, - pw_aff_to_expr(lbound), pw_aff_to_expr(ubound), inner_ast))) + pw_aff_to_expr(simplify_pw_aff(lbound, kernel.assumptions)), + pw_aff_to_expr(simplify_pw_aff(ubound, kernel.assumptions)), + inner_ast))) return merge_codegen_results(codegen_state, result) diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 602830de3..0ebe90fbc 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -142,6 +142,55 @@ def iname_rel_aff(space, iname, rel, aff): raise ValueError("unknown value of 'rel': %s" % rel) +# {{{ simplify_pw_aff + +def simplify_pw_aff(pw_aff, context=None): + if context is not None: + pw_aff = pw_aff.gist_params(context) + + old_pw_aff = pw_aff + + while True: + restart = False + did_something = False + + pieces = pw_aff.get_pieces() + for i, (dom_i, aff_i) in enumerate(pieces): + for j, (dom_j, aff_j) in enumerate(pieces): + if i == j: + continue + + if aff_i.gist(dom_j).is_equal(aff_j): + # aff_i is sufficient to conver aff_j, eliminate aff_j + new_pieces = pieces[:] + if i < j: + new_pieces.pop(j) + new_pieces.pop(i) + else: + new_pieces.pop(i) + new_pieces.pop(j) + + pw_aff = isl.PwAff.alloc(dom_i | dom_j, aff_i) + for dom, aff in new_pieces: + pw_aff = pw_aff.union_max(isl.PwAff.alloc(dom, aff)) + + restart = True + did_something = True + break + + if restart: + break + + if not did_something: + break + + assert pw_aff.get_aggregate_domain() <= pw_aff.eq_set(old_pw_aff) + + return pw_aff + +# }}} + + # {{{ static_*_of_pw_aff def static_extremum_of_pw_aff(pw_aff, constants_only, set_method, what, context): -- GitLab