diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 226fb45400283c240ecf172ed44da61fce14b0ee..1496bf84cc178bbe1f8512b74c2364667b341c96 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -1075,6 +1075,28 @@ class LoopKernel(ImmutableRecordWithoutPickling):
 
     # {{{ pretty-printing
 
+    @memoize_method
+    def _get_iname_order_for_printing(self):
+        """
+        Return a :class:`dict` mapping inames to sort keys that determine the
+        order in which the inames of an instruction are printed.
+
+        If no consistent nesting order can be determined, a warning is issued
+        and each iname is mapped to itself, i.e. inames sort by name.
+        """
+        try:
+            from loopy.kernel.tools import get_visual_iname_order_embedding
+            embedding = get_visual_iname_order_embedding(self)
+        except ValueError:
+            from loopy.diagnostic import warn_with_kernel
+            warn_with_kernel(self,
+                "iname-order",
+                "get_visual_iname_order_embedding() could not determine a "
+                "consistent iname nesting order")
+            embedding = dict((iname, iname) for iname in self.all_inames())
+
+        return embedding
+
     def stringify(self, what=None, with_dependencies=False):
         all_what = set([
             "name",
@@ -1228,7 +1250,9 @@ class LoopKernel(ImmutableRecordWithoutPickling):
                     raise LoopyError("unexpected instruction type: %s"
                             % type(insn).__name__)
 
-                loop_list = ",".join(natsorted(kernel.insn_inames(insn)))
+                order = self._get_iname_order_for_printing()
+                loop_list = ",".join(
+                    sorted(kernel.insn_inames(insn), key=lambda iname: order[iname]))
 
                 options = [Fore.GREEN+insn.id+Style.RESET_ALL]
                 if insn.priority:
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 85a8da936469e97577af742bf39286acb0188206..af64bf977600b677669eac7fe33f87febee78dcc 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -1100,6 +1100,156 @@ def guess_var_shape(kernel, var_name):
 # }}}
 
 
+# {{{ loop nest tracker
+
+class SetTrie(object):
+    """
+    Similar to a trie, but uses an unordered sequence as the key.
+
+    Every edge is labeled with a :class:`frozenset` of items. Inserting a key
+    splits existing edges as needed, so that keys sharing items share a path
+    from the root. An item may only ever appear on a single edge.
+    """
+
+    def __init__(self, children=(), all_items=None):
+        """
+        :arg children: initial edges, in any form accepted by :class:`dict`,
+            mapping :class:`frozenset` labels to :class:`SetTrie` nodes.
+        :arg all_items: the :class:`set` of items present in the trie, or
+            *None* to start a new, empty trie.
+        """
+        self.children = dict(children)
+        # all_items should be shared within a trie.
+        if all_items is None:
+            self.all_items = set()
+        else:
+            self.all_items = all_items
+
+    def descend(self, on_found=lambda prefix: None, prefix=frozenset()):
+        """
+        Call *on_found* with the edge label leading to each node, in pre-order.
+        Siblings are visited in the order of their sorted labels.
+
+        :arg prefix: the label passed to *on_found* for this node.
+        """
+        on_found(prefix)
+        from six import iteritems
+        for child_prefix, child in sorted(
+                iteritems(self.children),
+                key=lambda it: sorted(it[0])):
+            child.descend(on_found, prefix=child_prefix)
+
+    def check_consistent_insert(self, items_to_insert):
+        """
+        Raise :class:`ValueError` if any of *items_to_insert* is already
+        present somewhere in the trie.
+        """
+        if items_to_insert & self.all_items:
+            raise ValueError("inconsistent nesting")
+
+    def add_or_update(self, key):
+        """
+        Insert the set *key* below this node, splitting an existing edge if
+        *key* only partially overlaps it. An empty *key* is a no-op.
+
+        :raises ValueError: if inserting *key* would place an item that is
+            already in the trie on a second edge.
+        """
+        if not key:
+            return
+
+        from six import iteritems
+
+        for child_key, child in iteritems(self.children):
+            common = child_key & key
+            if common:
+                break
+        else:
+            # Key not found - insert new child
+            self.check_consistent_insert(key)
+            self.children[frozenset(key)] = SetTrie(all_items=self.all_items)
+            self.all_items.update(key)
+            return
+
+        if child_key <= key:
+            # child is a prefix of key:
+            child.add_or_update(key - common)
+        elif key < child_key:
+            # key is a strict prefix of child:
+            #
+            #  -[new child]
+            #   |
+            #  [child]
+            #
+            del self.children[child_key]
+            self.children[common] = SetTrie(
+                children={frozenset(child_key - common): child},
+                all_items=self.all_items)
+        else:
+            # key and child share a common prefix:
+            #
+            #     -[new placeholder]
+            #      /        \
+            #  [new child]  [child]
+            #
+            self.check_consistent_insert(key - common)
+
+            del self.children[child_key]
+            self.children[common] = SetTrie(
+                children={
+                    frozenset(child_key - common): child,
+                    frozenset(key - common): SetTrie(all_items=self.all_items)},
+                all_items=self.all_items)
+            self.all_items.update(key - common)
+
+
+def get_visual_iname_order_embedding(kernel):
+    """
+    Return :class:`dict` `embedding` mapping inames to a totally ordered set of
+    values, such that `embedding[iname1] < embedding[iname2]` when `iname2`
+    is nested inside `iname1`.
+
+    Only inames occurring in the `within_inames` of some instruction, or
+    carrying an ILP tag, are present in `embedding`.
+
+    :raises ValueError: if the inames of the instructions in *kernel* do not
+        form a consistent loop nest.
+    """
+    from loopy.kernel.data import IlpBaseTag
+    # Ignore ILP tagged inames, since they do not have to form a strict loop
+    # nest.
+    ilp_inames = frozenset(
+        iname for iname in kernel.iname_to_tag
+        if isinstance(kernel.iname_to_tag[iname], IlpBaseTag))
+
+    iname_trie = SetTrie()
+
+    for insn in kernel.instructions:
+        within_inames = set(
+            iname for iname in insn.within_inames
+            if iname not in ilp_inames)
+        iname_trie.add_or_update(within_inames)
+
+    embedding = {}
+
+    def update_embedding(inames):
+        # Inames on the same trie edge share a rank; the name breaks ties.
+        rank = len(embedding)
+        embedding.update(
+            dict((iname, (rank, iname)) for iname in inames))
+
+    iname_trie.descend(update_embedding)
+
+    # Nest ilp_inames innermost, so they don't interrupt visual order. Iterate
+    # in sorted order so that the ranks do not depend on set iteration order.
+    for iname in sorted(ilp_inames):
+        embedding[iname] = (len(embedding), iname)
+
+    return embedding
+
+# }}}
+
+
 # {{{ find_recursive_dependencies
 
 def find_recursive_dependencies(kernel, insn_ids):
diff --git a/test/test_misc.py b/test/test_misc.py
index 94d83cd29c6a89a7d01cc23cca2a2d9f1985eba9..ca4aee5b816eee3eae491c3eecdad296ac04323f 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -23,6 +23,7 @@ THE SOFTWARE.
 """
 
 import six  # noqa
+import pytest
 from six.moves import range
 
 import sys
@@ -75,6 +76,22 @@ def test_compute_sccs():
             verify_sccs(graph, compute_sccs(graph))
 
 
+def test_SetTrie():
+    from loopy.kernel.tools import SetTrie
+
+    s = SetTrie()
+    s.add_or_update(set([1, 2, 3]))
+    s.add_or_update(set([4, 2, 1]))
+    s.add_or_update(set([1, 5]))
+
+    result = []
+    s.descend(lambda prefix: result.extend(prefix))
+    assert result == [1, 2, 3, 4, 5]
+
+    with pytest.raises(ValueError):
+        s.add_or_update(set([1, 4]))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])