diff --git a/pymbolic/imperative/utils.py b/pymbolic/imperative/utils.py
index 5221859af1a5349e3cccaf0e8b31ff0d64c3cadc..39d0d138257610acf0ad18e17af8c38f9807e84c 100644
--- a/pymbolic/imperative/utils.py
+++ b/pymbolic/imperative/utils.py
@@ -34,40 +34,66 @@ logger = logging.getLogger(__name__)
 
 # {{{ graphviz / dot export
 
-def get_dot_dependency_graph(instructions, use_insn_ids=False,
-        addtional_lines_hook=None):
+def _default_preamble_hook():
+    # Sets default attributes for nodes and edges.
+    yield "node [shape=\"box\"];"
+    yield "edge [dir=\"back\"];"
+
+
+def get_dot_dependency_graph(
+        statements,  use_stmt_ids=None,
+        preamble_hook=_default_preamble_hook,
+        additional_lines_hook=list,
+
+        # deprecated
+        use_insn_ids=None,):
     """Return a string in the `dot <http://graphviz.org/>`_ language depicting
-    dependencies among kernel instructions.
+    dependencies among kernel statements.
+
+    :arg preamble_hook: A function that returns an iterable of lines
+        to add at the beginning of the graph
+    :arg additional_lines_hook: A function that returns an iterable
+        of lines to add at the end of the graph
     """
 
-    lines = []
+    if use_stmt_ids is not None and use_insn_ids is not None:
+        raise TypeError("may not specify both use_stmt_ids and use_insn_ids")
+
+    if use_insn_ids is not None:
+        use_stmt_ids = use_insn_ids
+        from warnings import warn
+        warn("'use_insn_ids' is deprecated. Use 'use_stmt_ids' instead.",
+                DeprecationWarning, stacklevel=2)
+
+    def get_node_attrs(stmt):
+        if use_stmt_ids:
+            stmt_label = stmt.id
+            tooltip = str(stmt)
+        else:
+            stmt_label = str(stmt)
+            tooltip = stmt.id
+
+        return "label=\"%s\",shape=\"box\",tooltip=\"%s\"" % (
+                repr(stmt_label)[1:-1],
+                repr(tooltip)[1:-1],
+                )
+
+    lines = list(preamble_hook())
     dep_graph = {}
 
     # maps (oriented) edge onto annotation string
     annotation_dep_graph = {}
 
-    for insn in instructions:
-        if use_insn_ids:
-            insn_label = insn.id
-            tooltip = str(insn)
-        else:
-            insn_label = str(insn)
-            tooltip = insn.id
-
-        lines.append("\"%s\" [label=\"%s\",shape=\"box\",tooltip=\"%s\"];"
-                % (
-                    insn.id,
-                    repr(insn_label)[1:-1],
-                    repr(tooltip)[1:-1],
-                    ))
-        for dep in insn.depends_on:
-            dep_graph.setdefault(insn.id, set()).add(dep)
+    for stmt in statements:
+        lines.append("\"%s\" [%s];" % (stmt.id, get_node_attrs(stmt)))
+        for dep in stmt.depends_on:
+            dep_graph.setdefault(stmt.id, set()).add(dep)
 
         if 0:
-            for dep in insn.then_depends_on:
-                annotation_dep_graph[(insn.id, dep)] = "then"
-            for dep in insn.else_depends_on:
-                annotation_dep_graph[(insn.id, dep)] = "else"
+            for dep in stmt.then_depends_on:
+                annotation_dep_graph[(stmt.id, dep)] = "then"
+            for dep in stmt.else_depends_on:
+                annotation_dep_graph[(stmt.id, dep)] = "else"
 
     # {{{ O(n^3) (i.e. slow) transitive reduction
 
@@ -75,35 +101,34 @@ def get_dot_dependency_graph(instructions, use_insn_ids=False,
     while True:
         changed_something = False
 
-        for insn_1 in dep_graph:
-            for insn_2 in dep_graph.get(insn_1, set()).copy():
-                for insn_3 in dep_graph.get(insn_2, set()).copy():
-                    if insn_3 not in dep_graph.get(insn_1, set()):
+        for stmt_1 in dep_graph:
+            for stmt_2 in dep_graph.get(stmt_1, set()).copy():
+                for stmt_3 in dep_graph.get(stmt_2, set()).copy():
+                    if stmt_3 not in dep_graph.get(stmt_1, set()):
                         changed_something = True
-                        dep_graph[insn_1].add(insn_3)
+                        dep_graph[stmt_1].add(stmt_3)
 
         if not changed_something:
             break
 
-    for insn_1 in dep_graph:
-        for insn_2 in dep_graph.get(insn_1, set()).copy():
-            for insn_3 in dep_graph.get(insn_2, set()).copy():
-                if insn_3 in dep_graph.get(insn_1, set()):
-                    dep_graph[insn_1].remove(insn_3)
+    for stmt_1 in dep_graph:
+        for stmt_2 in dep_graph.get(stmt_1, set()).copy():
+            for stmt_3 in dep_graph.get(stmt_2, set()).copy():
+                if stmt_3 in dep_graph.get(stmt_1, set()):
+                    dep_graph[stmt_1].remove(stmt_3)
 
     # }}}
 
-    for insn_1 in dep_graph:
-        for insn_2 in dep_graph.get(insn_1, set()):
-            lines.append("%s -> %s" % (insn_2, insn_1))
+    for stmt_1 in dep_graph:
+        for stmt_2 in dep_graph.get(stmt_1, set()):
+            lines.append("%s -> %s" % (stmt_2, stmt_1))
 
-    for (insn_1, insn_2), annot in six.iteritems(annotation_dep_graph):
+    for (stmt_1, stmt_2), annot in six.iteritems(annotation_dep_graph):
             lines.append(
-                    "%s -> %s  [label=\"%s\", style=dashed]"
-                    % (insn_2, insn_1, annot))
+                    "%s -> %s  [label=\"%s\",style=\"dashed\"]"
+                    % (stmt_2, stmt_1, annot))
 
-    if addtional_lines_hook is not None:
-        lines.extend(addtional_lines_hook())
+    lines.extend(additional_lines_hook())
 
     return "digraph code {\n%s\n}" % (
             "\n".join(lines)