Skip to content
Snippets Groups Projects
Commit e2e4372a authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Centralize convexification. Better code for get_bounds_checks.

parent 00791241
No related branches found
No related tags found
No related merge requests found
......@@ -97,14 +97,14 @@ Future ideas
- Try, fix indirect addressing
- Use gists (why do disjoint sets arise?)
- Nested slab decomposition (in conjunction with conditional hoisting) could
generate nested conditional code.
Dealt with
^^^^^^^^^^
- Use gists (why do disjoint sets arise?)
- Automatically verify that all array access is within bounds.
- : (as in, Matlab full-slice) in prefetches
......
......@@ -105,43 +105,36 @@ def constraint_to_code(ccm, cns):
from loopy.symbolic import constraint_to_expr
return "%s %s 0" % (ccm(constraint_to_expr(cns), 'i'), comp_op)
def filter_necessary_constraints(implemented_domain, constraints):
return [cns
for cns in constraints
if not implemented_domain.is_subset(
isl.align_spaces(
isl.BasicSet.universe(cns.get_space()).add_constraint(cns),
implemented_domain))]
def generate_bounds_checks(domain, check_inames, implemented_domain):
"""Will not overapproximate if check_inames consists of all inames in the domain."""
"""Will not overapproximate."""
if len(check_inames) == domain.dim(dim_type.set):
assert check_inames == frozenset(domain.get_var_names(dim_type.set))
else:
domain = (domain
.eliminate_except(check_inames, [dim_type.set])
.remove_divs())
domain = (domain
.eliminate_except(check_inames, [dim_type.set])
.compute_divs())
if isinstance(domain, isl.Set):
bsets = domain.get_basic_sets()
if len(bsets) == 1:
domain_bset, = bsets
else:
if len(bsets) != 1:
domain = domain.coalesce()
bsets = domain.get_basic_sets()
if len(bsets) == 1:
if len(bsets) != 1:
raise RuntimeError("domain of inames '%s' projected onto '%s' "
"did not reduce to a single conjunction"
% (", ".join(domain.get_var_names(dim_type.set)),
check_inames))
domain, = bsets
else:
domain_bset = domain
domain = domain
domain = domain.remove_redundancies()
domain = isl.Set.from_basic_set(domain)
domain = isl.align_spaces(domain, implemented_domain)
domain_bset = domain_bset.remove_redundancies()
result = domain.gist(implemented_domain)
return filter_necessary_constraints(
implemented_domain, domain_bset.get_constraints())
from loopy.isl_helpers import convexify
return convexify(result).get_constraints()
def wrap_in_bounds_checks(ccm, domain, check_inames, implemented_domain, stmt):
bounds_checks = generate_bounds_checks(
......
......@@ -650,33 +650,8 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
storage_axis_names, storage_axis_sources,
sweep_inames, invocation_descriptors)
# {{{ try a few ways to get new_domain to be convex
if len(new_domain.get_basic_sets()) > 1:
hull_new_domain = new_domain.simple_hull()
if isl.Set.from_basic_set(hull_new_domain) <= new_domain:
new_domain = hull_new_domain
new_domain = new_domain.coalesce()
if len(new_domain.get_basic_sets()) > 1:
hull_new_domain = new_domain.simple_hull()
if isl.Set.from_basic_set(hull_new_domain) <= new_domain:
new_domain = hull_new_domain
if isinstance(new_domain, isl.Set):
dom_bsets = new_domain.get_basic_sets()
if len(dom_bsets) > 1:
print "PIECES:"
for dbs in dom_bsets:
print " %s" % (isl.Set.from_basic_set(dbs).gist(new_domain))
raise NotImplementedError("Substitution '%s' yielded a non-convex footprint"
% subst_name)
new_domain, = dom_bsets
# }}}
from loopy.isl_helpers import convexify
new_domain = convexify(new_domain)
# {{{ set up compute insn
......
......@@ -277,6 +277,44 @@ def is_nonnegative(expr, over_set):
def convexify(domain):
"""Try a few ways to get *domain* to be a BasicSet, i.e.
explicitly convex.
"""
if isinstance(domain, isl.BasicSet):
return domain
dom_bsets = domain.get_basic_sets()
if len(dom_bsets) == 1:
domain, = dom_bsets
return domain
hull_domain = domain.simple_hull()
if isl.Set.from_basic_set(hull_domain) <= domain:
return hull_domain
domain = domain.coalesce()
dom_bsets = domain.get_basic_sets()
if len(domain.get_basic_sets()) == 1:
domain, = dom_bsets
return domain
hull_domain = domain.simple_hull()
if isl.Set.from_basic_set(hull_domain) <= domain:
return hull_domain
dom_bsets = domain.get_basic_sets()
assert len(dom_bsets) > 1
print "PIECES:"
for dbs in dom_bsets:
print " %s" % (isl.Set.from_basic_set(dbs).gist(domain))
raise NotImplementedError("Could not find convex representation of set")
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment