Skip to content
Snippets Groups Projects
Commit 0cb04b55 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge branch 'master' of ssh://gitlab.tiker.net/inducer/loopy

parents d81afed5 a26a8207
No related branches found
No related tags found
No related merge requests found
......@@ -1075,6 +1075,21 @@ class LoopKernel(ImmutableRecordWithoutPickling):
# {{{ pretty-printing
@memoize_method
def _get_iname_order_for_printing(self):
try:
from loopy.kernel.tools import get_visual_iname_order_embedding
embedding = get_visual_iname_order_embedding(self)
except ValueError:
from loopy.diagnostic import warn_with_kernel
warn_with_kernel(self,
"iname-order",
"get_visual_iname_order_embedding() could not determine a "
"consistent iname nesting order")
embedding = dict((iname, iname) for iname in self.all_inames())
return embedding
def stringify(self, what=None, with_dependencies=False):
all_what = set([
"name",
......@@ -1228,7 +1243,9 @@ class LoopKernel(ImmutableRecordWithoutPickling):
raise LoopyError("unexpected instruction type: %s"
% type(insn).__name__)
loop_list = ",".join(natsorted(kernel.insn_inames(insn)))
order = self._get_iname_order_for_printing()
loop_list = ",".join(
sorted(kernel.insn_inames(insn), key=lambda iname: order[iname]))
options = [Fore.GREEN+insn.id+Style.RESET_ALL]
if insn.priority:
......
......@@ -1100,6 +1100,120 @@ def guess_var_shape(kernel, var_name):
# }}}
# {{{ loop nest tracker
class SetTrie(object):
"""
Similar to a trie, but uses an unordered sequence as the key.
"""
def __init__(self, children=(), all_items=None):
self.children = dict(children)
# all_items should be shared within a trie.
if all_items is None:
self.all_items = set()
else:
self.all_items = all_items
def descend(self, on_found=lambda prefix: None, prefix=frozenset()):
on_found(prefix)
from six import iteritems
for prefix, child in sorted(
iteritems(self.children),
key=lambda it: sorted(it[0])):
child.descend(on_found, prefix=prefix)
def check_consistent_insert(self, items_to_insert):
if items_to_insert & self.all_items:
raise ValueError("inconsistent nesting")
def add_or_update(self, key):
if len(key) == 0:
return
from six import iteritems
for child_key, child in iteritems(self.children):
common = child_key & key
if common:
break
else:
# Key not found - insert new child
self.check_consistent_insert(key)
self.children[frozenset(key)] = SetTrie(all_items=self.all_items)
self.all_items.update(key)
return
if child_key <= key:
# child is a prefix of key:
child.add_or_update(key - common)
elif key < child_key:
# key is a strict prefix of child:
#
# -[new child]
# |
# [child]
#
del self.children[child_key]
self.children[common] = SetTrie(
children={frozenset(child_key - common): child},
all_items=self.all_items)
else:
# key and child share a common prefix:
#
# -[new placeholder]
# / \
# [new child] [child]
#
self.check_consistent_insert(key - common)
del self.children[child_key]
self.children[common] = SetTrie(
children={
frozenset(child_key - common): child,
frozenset(key - common): SetTrie(all_items=self.all_items)},
all_items=self.all_items)
self.all_items.update(key - common)
def get_visual_iname_order_embedding(kernel):
"""
Return :class:`dict` `embedding` mapping inames to a totally ordered set of
values, such that `embedding[iname1] < embedding[iname2]` when `iname2`
is nested inside `iname1`.
"""
from loopy.kernel.data import IlpBaseTag
# Ignore ILP tagged inames, since they do not have to form a strict loop
# nest.
ilp_inames = frozenset(
iname for iname in kernel.iname_to_tag
if isinstance(kernel.iname_to_tag[iname], IlpBaseTag))
iname_trie = SetTrie()
for insn in kernel.instructions:
within_inames = set(
iname for iname in insn.within_inames
if iname not in ilp_inames)
iname_trie.add_or_update(within_inames)
embedding = {}
def update_embedding(inames):
embedding.update(
dict((iname, (len(embedding), iname)) for iname in inames))
iname_trie.descend(update_embedding)
for iname in ilp_inames:
# Nest ilp_inames innermost, so they don't interrupt visual order.
embedding[iname] = (len(embedding), iname)
return embedding
# }}}
# {{{ find_recursive_dependencies
def find_recursive_dependencies(kernel, insn_ids):
......
......@@ -23,6 +23,7 @@ THE SOFTWARE.
"""
import six # noqa
import pytest
from six.moves import range
import sys
......@@ -75,6 +76,22 @@ def test_compute_sccs():
verify_sccs(graph, compute_sccs(graph))
def test_SetTrie():
from loopy.kernel.tools import SetTrie
s = SetTrie()
s.add_or_update(set([1, 2, 3]))
s.add_or_update(set([4, 2, 1]))
s.add_or_update(set([1, 5]))
result = []
s.descend(lambda prefix: result.extend(prefix))
assert result == [1, 2, 3, 4, 5]
with pytest.raises(ValueError):
s.add_or_update(set([1, 4]))
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment