Skip to content
Snippets Groups Projects
Commit dd485bb0 authored by James Stevens's avatar James Stevens
Browse files

Merge remote-tracking branch 'upstream/master'

parents 6ff4c160 ae326839
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ from islpy import dim_type
from loopy.symbolic import (get_dependencies, SubstitutionMapper)
from pymbolic.mapper.substitutor import make_subst_func
from pytools import Record
from pytools import Record, memoize_method
from pymbolic import var
......@@ -64,7 +64,7 @@ def to_parameters_or_project_out(param_inames, set_inames, set):
# {{{ construct storage->sweep map
def build_per_access_storage_to_domain_map(accdesc, domain,
def build_per_access_storage_to_domain_map(storage_axis_exprs, domain,
storage_axis_names,
prime_sweep_inames):
......@@ -91,7 +91,7 @@ def build_per_access_storage_to_domain_map(accdesc, domain,
from loopy.symbolic import aff_from_expr
for saxis, sa_expr in zip(storage_axis_names, accdesc.storage_axis_exprs):
for saxis, sa_expr in zip(storage_axis_names, storage_axis_exprs):
cns = isl.Constraint.equality_from_aff(
aff_from_expr(set_space,
var(saxis+"'") - prime_sweep_inames(sa_expr)))
......@@ -138,7 +138,7 @@ def build_global_storage_to_sweep_map(kernel, access_descriptors,
# build footprint
for accdesc in access_descriptors:
stor2sweep = build_per_access_storage_to_domain_map(
accdesc, domain_dup_sweep,
accdesc.storage_axis_exprs, domain_dup_sweep,
storage_axis_names,
prime_sweep_inames)
......@@ -336,6 +336,11 @@ class ArrayToBufferMap(object):
return convexify(domain)
def is_access_descriptor_in_footprint(self, accdesc):
return self._is_access_descriptor_in_footprint_inner(
tuple(accdesc.storage_axis_exprs))
@memoize_method
def _is_access_descriptor_in_footprint_inner(self, storage_axis_exprs):
# Make all inames except the sweep parameters. (The footprint may depend on
# those.) (I.e. only leave sweep inames as out parameters.)
......@@ -347,7 +352,7 @@ class ArrayToBufferMap(object):
set(global_s2s_par_dom.get_var_names(dim_type.param))
& self.kernel.all_inames())
for arg in accdesc.storage_axis_exprs:
for arg in storage_axis_exprs:
arg_inames.update(get_dependencies(arg))
arg_inames = frozenset(arg_inames)
......@@ -363,7 +368,8 @@ class ArrayToBufferMap(object):
usage_domain = usage_domain.set_dim_name(
dim_type.set, i, iname+"'")
stor2sweep = build_per_access_storage_to_domain_map(accdesc,
stor2sweep = build_per_access_storage_to_domain_map(
storage_axis_exprs,
usage_domain, self.storage_axis_names,
self.prime_sweep_inames)
......
......@@ -384,4 +384,129 @@ def obj_involves_variable(obj, var_name):
return False
# {{{ performance tweak for dim_{min,max}: project first
def _runs_in_integer_set(s, max_int=None):
if not s:
return
if max_int is None:
max_int = max(s)
i = 0
while i < max_int:
if i in s:
start = i
i += 1
while i < max_int and i in s:
i += 1
end = i
yield (start, end-start)
else:
i += 1
class TooManyInteractingDims(Exception):
pass
def _find_aff_dims(aff, dim_types_and_gen_dim_types):
result = []
for dt, gen_dt in dim_types_and_gen_dim_types:
for i in range(aff.dim(dt)):
if not aff.get_coefficient_val(dt, i).is_zero():
result.append((gen_dt, i))
result = set(result)
for i in range(aff.dim(dim_type.div)):
if not aff.get_coefficient_val(dim_type.div, i).is_zero():
result.update(_find_aff_dims(
aff.get_div(i),
dim_types_and_gen_dim_types))
return result
def _transitive_closure(graph_dict):
pass
def _find_noninteracting_dims(obj, dt, idx, other_dt, stop_at=6):
if isinstance(obj, isl.BasicSet):
basics = [obj]
elif isinstance(obj, isl.Set):
basics = obj.get_basic_sets()
else:
raise TypeError("unsupported arg type '%s'" % type(obj))
connections = []
for bs in basics:
for c in bs.get_constraints():
conn = _find_aff_dims(
c.get_aff(),
[(dim_type.param, dim_type.param), (dim_type.in_, dim_type.set)])
if len(conn) > 1:
connections.append(conn)
interacting = set([(dt, idx)])
while True:
changed_something = False
# Compute the connected component near (dt, idx) by fixed point iteration
for conn in connections:
prev_len = len(interacting)
overlap = interacting & conn
if overlap:
interacting.update(conn)
if len(interacting) != prev_len:
changed_something = True
if len(interacting) >= stop_at:
raise TooManyInteractingDims()
if not changed_something:
break
return set(range(obj.dim(other_dt))) - set(
idx for dt, idx in interacting
if dt == other_dt)
def _eliminate_noninteracting(obj, dt, idx, other_dt):
obj = obj.compute_divs()
try:
nonint = _find_noninteracting_dims(obj, dt, idx, other_dt)
except TooManyInteractingDims:
return obj
for first, n in _runs_in_integer_set(nonint):
obj = obj.eliminate(other_dt, first, n)
return obj
def dim_min_with_elimination(obj, idx):
obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param)
return obj_elim.dim_min(idx)
def dim_max_with_elimination(obj, idx):
obj_elim = _eliminate_noninteracting(obj, dim_type.out, idx, dim_type.param)
return obj_elim.dim_max(idx)
# }}}
# vim: foldmethod=marker
......@@ -251,15 +251,17 @@ class SetOperationCacheManager:
return result
#print op, set.get_dim_name(dim_type.set, args[0])
result = op(*args)
result = op(set, *args)
bucket.append((set, op_name, args, result))
return result
def dim_min(self, set, *args):
return self.op(set, "dim_min", set.dim_min, args)
from loopy.isl_helpers import dim_min_with_elimination
return self.op(set, "dim_min", dim_min_with_elimination, args)
def dim_max(self, set, *args):
return self.op(set, "dim_max", set.dim_max, args)
from loopy.isl_helpers import dim_max_with_elimination
return self.op(set, "dim_max", dim_max_with_elimination, args)
def base_index_and_length(self, set, iname, context=None):
if not isinstance(iname, int):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment