From d5663834d5371f68ef0bbca07ba85be3d554e33a Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Tue, 24 Jan 2017 13:14:02 -0600
Subject: [PATCH] Kernel printing: Improve iname ordering in output.

This adds a tool that efficiently computes the iname nesting order, which is
used to sort inames when printing a kernel.

Closes #5
---
 loopy/kernel/__init__.py |  19 ++++++-
 loopy/kernel/tools.py    | 114 +++++++++++++++++++++++++++++++++++++++
 test/test_misc.py        |  17 ++++++
 3 files changed, 149 insertions(+), 1 deletion(-)

diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 71b112775..008c0c18a 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -1075,6 +1075,21 @@ class LoopKernel(ImmutableRecordWithoutPickling):
 
     # {{{ pretty-printing
 
+    @memoize_method
+    def _get_iname_order_for_printing(self):
+        try:
+            from loopy.kernel.tools import get_visual_iname_order_embedding
+            embedding = get_visual_iname_order_embedding(self)
+        except ValueError:
+            from loopy.diagnostic import warn_with_kernel
+            warn_with_kernel(self,
+                "iname-order",
+                "LoopNestTracker could not determine a consistent iname "
+                "nesting order")
+            embedding = dict((iname, iname) for iname in self.all_inames())
+
+        return embedding
+
     def stringify(self, what=None, with_dependencies=False):
         all_what = set([
             "name",
@@ -1214,7 +1229,9 @@ class LoopKernel(ImmutableRecordWithoutPickling):
                     raise LoopyError("unexpected instruction type: %s"
                             % type(insn).__name__)
 
-                loop_list = ",".join(sorted(kernel.insn_inames(insn)))
+                order = self._get_iname_order_for_printing()
+                loop_list = ",".join(
+                    sorted(kernel.insn_inames(insn), key=lambda iname: order[iname]))
 
                 options = [Fore.GREEN+insn.id+Style.RESET_ALL]
                 if insn.priority:
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 85a8da936..af64bf977 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -1100,6 +1100,120 @@ def guess_var_shape(kernel, var_name):
 # }}}
 
 
+# {{{ loop nest tracker
+
+class SetTrie(object):
+    """
+    Similar to a trie, but each key is a set (an unordered collection) of items.
+    """
+
+    def __init__(self, children=(), all_items=None):
+        self.children = dict(children)
+        # all_items should be shared within a trie.
+        if all_items is None:
+            self.all_items = set()
+        else:
+            self.all_items = all_items
+
+    def descend(self, on_found=lambda prefix: None, prefix=frozenset()):
+        on_found(prefix)
+        from six import iteritems
+        for prefix, child in sorted(
+                iteritems(self.children),
+                key=lambda it: sorted(it[0])):
+            child.descend(on_found, prefix=prefix)
+
+    def check_consistent_insert(self, items_to_insert):
+        if items_to_insert & self.all_items:
+            raise ValueError("inconsistent nesting")
+
+    def add_or_update(self, key):
+        if len(key) == 0:
+            return
+
+        from six import iteritems
+
+        for child_key, child in iteritems(self.children):
+            common = child_key & key
+            if common:
+                break
+        else:
+            # Key not found - insert new child
+            self.check_consistent_insert(key)
+            self.children[frozenset(key)] = SetTrie(all_items=self.all_items)
+            self.all_items.update(key)
+            return
+
+        if child_key <= key:
+            # child is a prefix of key:
+            child.add_or_update(key - common)
+        elif key < child_key:
+            # key is a strict prefix of child:
+            #
+            #  -[new child]
+            #     |
+            #   [child]
+            #
+            del self.children[child_key]
+            self.children[common] = SetTrie(
+                children={frozenset(child_key - common): child},
+                all_items=self.all_items)
+        else:
+            # key and child share a common prefix:
+            #
+            # -[new placeholder]
+            #      /        \
+            #  [new child]   [child]
+            #
+            self.check_consistent_insert(key - common)
+
+            del self.children[child_key]
+            self.children[common] = SetTrie(
+                children={
+                    frozenset(child_key - common): child,
+                    frozenset(key - common): SetTrie(all_items=self.all_items)},
+                all_items=self.all_items)
+            self.all_items.update(key - common)
+
+
+def get_visual_iname_order_embedding(kernel):
+    """
+    Return a :class:`dict` `embedding` that maps each iname to a value from a
+    totally ordered set, such that `embedding[iname1] < embedding[iname2]`
+    when `iname2` is nested inside `iname1`.
+    """
+    from loopy.kernel.data import IlpBaseTag
+    # Ignore ILP-tagged inames, since they do not have to form a strict loop
+    # nest.
+    ilp_inames = frozenset(
+        iname for iname in kernel.iname_to_tag
+        if isinstance(kernel.iname_to_tag[iname], IlpBaseTag))
+
+    iname_trie = SetTrie()
+
+    for insn in kernel.instructions:
+        within_inames = set(
+            iname for iname in insn.within_inames
+            if iname not in ilp_inames)
+        iname_trie.add_or_update(within_inames)
+
+    embedding = {}
+
+    def update_embedding(inames):
+        embedding.update(
+            dict((iname, (len(embedding), iname)) for iname in inames))
+
+    iname_trie.descend(update_embedding)
+
+    for iname in ilp_inames:
+        # Nest ilp_inames innermost, so they don't interrupt visual order.
+        embedding[iname] = (len(embedding), iname)
+
+    return embedding
+
+# }}}
+
+
 # {{{ find_recursive_dependencies
 
 def find_recursive_dependencies(kernel, insn_ids):
diff --git a/test/test_misc.py b/test/test_misc.py
index 94d83cd29..ca4aee5b8 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -23,6 +23,7 @@ THE SOFTWARE.
 """
 
 import six  # noqa
+import pytest
 from six.moves import range
 
 import sys
@@ -75,6 +76,22 @@ def test_compute_sccs():
             verify_sccs(graph, compute_sccs(graph))
 
 
+def test_SetTrie():
+    from loopy.kernel.tools import SetTrie
+
+    s = SetTrie()
+    s.add_or_update(set([1, 2, 3]))
+    s.add_or_update(set([4, 2, 1]))
+    s.add_or_update(set([1, 5]))
+
+    result = []
+    s.descend(lambda prefix: result.extend(prefix))
+    assert result == [1, 2, 3, 4, 5]
+
+    with pytest.raises(ValueError):
+        s.add_or_update(set([1, 4]))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab