diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py index 61a825fb9fef8c4d847e3e8f1310814e56e13a0a..7cc381f11d1239cba5656a9dc7a04cddaa14a368 100644 --- a/loopy/codegen/bounds.py +++ b/loopy/codegen/bounds.py @@ -27,30 +27,24 @@ import islpy as isl from islpy import dim_type -# {{{ bounds check generator +# {{{ approximate, convex bounds check generator -def get_bounds_checks(domain, check_inames, implemented_domain, - overapproximate): +def get_approximate_convex_bounds_checks(domain, check_inames, implemented_domain): if isinstance(domain, isl.BasicSet): domain = isl.Set.from_basic_set(domain) domain = domain.remove_redundancies() result = domain.eliminate_except(check_inames, [dim_type.set]) - if overapproximate: - # This is ok, because we're really looking for the - # projection, with no remaining constraints from - # the eliminated variables. - result = result.remove_divs() - else: - result = result.compute_divs() + # This is ok, because we're really looking for the + # projection, with no remaining constraints from + # the eliminated variables. + result = result.remove_divs() result, implemented_domain = isl.align_two(result, implemented_domain) result = result.gist(implemented_domain) - if overapproximate: - result = result.remove_divs() - else: - result = result.compute_divs() + # (see above) + result = result.remove_divs() from loopy.isl_helpers import convexify result = convexify(result) diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py index 3378ed81ee56f97cc11f8f8998aeb67221061633..6964b344427ad8e7ca967f91715fe95f0bbf91da 100644 --- a/loopy/codegen/control.py +++ b/loopy/codegen/control.py @@ -301,13 +301,11 @@ def build_loop_nest(codegen_state, schedule_index): domain = isl.align_spaces( self.kernel.get_inames_domain(check_inames), self.impl_domain, obj_bigger_ok=True) - from loopy.codegen.bounds import get_bounds_checks - return get_bounds_checks(domain, - check_inames, self.impl_domain, - - # Each instruction individually gets its bounds checks, - # so we can safely overapproximate here. - overapproximate=True) + from loopy.codegen.bounds import get_approximate_convex_bounds_checks + # Each instruction individually gets its bounds checks, + # so we can safely overapproximate here. + return get_approximate_convex_bounds_checks(domain, + check_inames, self.impl_domain) def build_insn_group(sched_index_info_entries, codegen_state, done_group_lengths=set()): @@ -451,13 +449,13 @@ def build_loop_nest(codegen_state, schedule_index): # gen_code returns a list if bounds_checks or pred_checks: - from loopy.symbolic import constraint_to_expr + from loopy.symbolic import constraint_to_cond_expr prev_gen_code = gen_code def gen_code(inner_codegen_state): condition_exprs = [ - constraint_to_expr(cns) + constraint_to_cond_expr(cns) for cns in bounds_checks] + [ pred_chk for pred_chk in pred_checks] diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py index 140ec644731d570fac2e793f0c4e5ea004d165e6..c490abb6ed1635c135fc77468f27cd833b1d57b2 100644 --- a/loopy/codegen/instruction.py +++ b/loopy/codegen/instruction.py @@ -27,6 +27,7 @@ THE SOFTWARE. from six.moves import range import islpy as isl +dim_type = isl.dim_type from loopy.codegen import Unvectorizable from loopy.codegen.result import CodeGenerationResult from pymbolic.mapper.stringifier import PREC_NONE @@ -34,24 +35,27 @@ from pymbolic.mapper.stringifier import PREC_NONE def to_codegen_result( codegen_state, insn_id, domain, check_inames, required_preds, ast): - from loopy.codegen.bounds import get_bounds_checks - from loopy.symbolic import constraint_to_expr - - bounds_checks = get_bounds_checks( - domain, check_inames, - codegen_state.implemented_domain, overapproximate=False) - bounds_check_set = isl.Set.universe(domain.get_space()) \ - .add_constraints(bounds_checks) - bounds_check_set, new_implemented_domain = isl.align_two( - bounds_check_set, codegen_state.implemented_domain) - new_implemented_domain = new_implemented_domain & bounds_check_set - - if bounds_check_set.is_empty(): + # {{{ get bounds check + + chk_domain = isl.Set.from_basic_set(domain) + chk_domain = chk_domain.remove_redundancies() + chk_domain = chk_domain.eliminate_except(check_inames, [dim_type.set]) + + chk_domain, implemented_domain = isl.align_two( + chk_domain, codegen_state.implemented_domain) + chk_domain = chk_domain.gist(implemented_domain) + + # }}} + + new_implemented_domain = implemented_domain & chk_domain + + if chk_domain.is_empty(): return None - condition_exprs = [ - constraint_to_expr(cns) - for cns in bounds_checks] + condition_exprs = [] + if not chk_domain.plain_is_universe(): + from loopy.symbolic import set_to_cond_expr + condition_exprs.append(set_to_cond_expr(chk_domain)) condition_exprs.extend( required_preds - codegen_state.implemented_predicates) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index c90a2b9f17f230a66c387a33e71c9006288ec641..b1743cb1af927fe8a68d8566ea61ce4e511c5e1a 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1140,25 +1140,7 @@ def pw_aff_to_expr(pw_aff, int_ok=False): pieces = pw_aff.get_pieces() last_expr = aff_to_expr(pieces[-1][1]) - # {{{ make exprs from set constraints - - from pymbolic.primitives import LogicalAnd, LogicalOr - - def set_to_expr(isl_set): - constrs = [] - for isl_basicset in isl_set.get_basic_sets(): - constrs.append(basic_set_to_expr(isl_basicset)) - return LogicalOr(tuple(constrs)) - - def basic_set_to_expr(isl_basicset): - constrs = [] - for constr in isl_basicset.get_constraints(): - constrs.append(constraint_to_expr(constr)) - return LogicalAnd(tuple(constrs)) - - # }}} - - pairs = [(set_to_expr(constr_set), aff_to_expr(aff)) + pairs = [(set_to_cond_expr(constr_set), aff_to_expr(aff)) for constr_set, aff in pieces[:-1]] from pymbolic.primitives import If @@ -1274,7 +1256,7 @@ def simplify_using_aff(kernel, expr): # }}} -# {{{ expression <-> constraint conversion +# {{{ expression/set <-> constraint conversion def eq_constraint_from_expr(space, expr): return isl.Constraint.equality_from_aff(aff_from_expr(space, expr)) @@ -1284,7 +1266,7 @@ def ineq_constraint_from_expr(space, expr): return isl.Constraint.inequality_from_aff(aff_from_expr(space, expr)) -def constraint_to_expr(cns): +def constraint_to_cond_expr(cns): # Looks like this is ok after all--get_aff() performs some magic. # Not entirely sure though... FIXME # @@ -1303,6 +1285,39 @@ def constraint_to_expr(cns): # }}} +# {{{ set_to_cond_expr + +def basic_set_to_cond_expr(isl_basicset): + constrs = [] + for constr in isl_basicset.get_constraints(): + constrs.append(constraint_to_cond_expr(constr)) + + if len(constrs) == 0: + raise ValueError("may not be called on universe") + elif len(constrs) == 1: + constr, = constrs + return constr + else: + return p.LogicalAnd(tuple(constrs)) + + +def set_to_cond_expr(isl_set): + conjs = [] + for isl_basicset in isl_set.get_basic_sets(): + conjs.append(basic_set_to_cond_expr(isl_basicset)) + + if len(conjs) == 0: + raise ValueError("may not be called on universe") + elif len(conjs) == 1: + conj, = conjs + return conj + else: + return p.LogicalOr(tuple(conjs)) + + +# }}} + + # {{{ Reduction callback mapper class ReductionCallbackMapper(IdentityMapper):