diff --git a/pymbolic/imperative/utils.py b/pymbolic/imperative/utils.py index 5221859af1a5349e3cccaf0e8b31ff0d64c3cadc..39d0d138257610acf0ad18e17af8c38f9807e84c 100644 --- a/pymbolic/imperative/utils.py +++ b/pymbolic/imperative/utils.py @@ -34,40 +34,66 @@ logger = logging.getLogger(__name__) # {{{ graphviz / dot export -def get_dot_dependency_graph(instructions, use_insn_ids=False, - addtional_lines_hook=None): +def _default_preamble_hook(): + # Sets default attributes for nodes and edges. + yield "node [shape=\"box\"];" + yield "edge [dir=\"back\"];" + + +def get_dot_dependency_graph( + statements, use_stmt_ids=None, + preamble_hook=_default_preamble_hook, + additional_lines_hook=list, + + # deprecated + use_insn_ids=None,): """Return a string in the `dot <http://graphviz.org/>`_ language depicting - dependencies among kernel instructions. + dependencies among kernel statements. + + :arg preamble_hook: A function that returns an iterable of lines + to add at the beginning of the graph + :arg additional_lines_hook: A function that returns an iterable + of lines to add at the end of the graph """ - lines = [] + if use_stmt_ids is not None and use_insn_ids is not None: + raise TypeError("may not specify both use_stmt_ids and use_insn_ids") + + if use_insn_ids is not None: + use_stmt_ids = use_insn_ids + from warnings import warn + warn("'use_insn_ids' is deprecated. Use 'use_stmt_ids' instead.", + DeprecationWarning, stacklevel=2) + + def get_node_attrs(stmt): + if use_stmt_ids: + stmt_label = stmt.id + tooltip = str(stmt) + else: + stmt_label = str(stmt) + tooltip = stmt.id + + return "label=\"%s\",shape=\"box\",tooltip=\"%s\"" % ( + repr(stmt_label)[1:-1], + repr(tooltip)[1:-1], + ) + + lines = list(preamble_hook()) dep_graph = {} # maps (oriented) edge onto annotation string annotation_dep_graph = {} - for insn in instructions: - if use_insn_ids: - insn_label = insn.id - tooltip = str(insn) - else: - insn_label = str(insn) - tooltip = insn.id - - lines.append("\"%s\" [label=\"%s\",shape=\"box\",tooltip=\"%s\"];" - % ( - insn.id, - repr(insn_label)[1:-1], - repr(tooltip)[1:-1], - )) - for dep in insn.depends_on: - dep_graph.setdefault(insn.id, set()).add(dep) + for stmt in statements: + lines.append("\"%s\" [%s];" % (stmt.id, get_node_attrs(stmt))) + for dep in stmt.depends_on: + dep_graph.setdefault(stmt.id, set()).add(dep) if 0: - for dep in insn.then_depends_on: - annotation_dep_graph[(insn.id, dep)] = "then" - for dep in insn.else_depends_on: - annotation_dep_graph[(insn.id, dep)] = "else" + for dep in stmt.then_depends_on: + annotation_dep_graph[(stmt.id, dep)] = "then" + for dep in stmt.else_depends_on: + annotation_dep_graph[(stmt.id, dep)] = "else" # {{{ O(n^3) (i.e. slow) transitive reduction @@ -75,35 +101,34 @@ def get_dot_dependency_graph(instructions, use_insn_ids=False, while True: changed_something = False - for insn_1 in dep_graph: - for insn_2 in dep_graph.get(insn_1, set()).copy(): - for insn_3 in dep_graph.get(insn_2, set()).copy(): - if insn_3 not in dep_graph.get(insn_1, set()): + for stmt_1 in dep_graph: + for stmt_2 in dep_graph.get(stmt_1, set()).copy(): + for stmt_3 in dep_graph.get(stmt_2, set()).copy(): + if stmt_3 not in dep_graph.get(stmt_1, set()): changed_something = True - dep_graph[insn_1].add(insn_3) + dep_graph[stmt_1].add(stmt_3) if not changed_something: break - for insn_1 in dep_graph: - for insn_2 in dep_graph.get(insn_1, set()).copy(): - for insn_3 in dep_graph.get(insn_2, set()).copy(): - if insn_3 in dep_graph.get(insn_1, set()): - dep_graph[insn_1].remove(insn_3) + for stmt_1 in dep_graph: + for stmt_2 in dep_graph.get(stmt_1, set()).copy(): + for stmt_3 in dep_graph.get(stmt_2, set()).copy(): + if stmt_3 in dep_graph.get(stmt_1, set()): + dep_graph[stmt_1].remove(stmt_3) # }}} - for insn_1 in dep_graph: - for insn_2 in dep_graph.get(insn_1, set()): - lines.append("%s -> %s" % (insn_2, insn_1)) + for stmt_1 in dep_graph: + for stmt_2 in dep_graph.get(stmt_1, set()): + lines.append("%s -> %s" % (stmt_2, stmt_1)) - for (insn_1, insn_2), annot in six.iteritems(annotation_dep_graph): + for (stmt_1, stmt_2), annot in six.iteritems(annotation_dep_graph): lines.append( - "%s -> %s [label=\"%s\", style=dashed]" - % (insn_2, insn_1, annot)) + "%s -> %s [label=\"%s\",style=\"dashed\"]" + % (stmt_2, stmt_1, annot)) - if addtional_lines_hook is not None: - lines.extend(addtional_lines_hook()) + lines.extend(additional_lines_hook()) return "digraph code {\n%s\n}" % ( "\n".join(lines)