diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index 8129757ca41c68300960b7f61f1045e19d3a7003..2016ba0c76c37179a71646b92c30b5dc6fa4d334 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -38,6 +38,8 @@ from pymbolic.mapper import ( WalkMapper as WalkMapperBase, CachingMapperMixin ) +from pymbolic.mapper.graphviz import ( + GraphvizMapper as GraphvizMapperBase) from pymbolic.mapper.stringifier import ( StringifyMapper as StringifyMapperBase, PREC_NONE @@ -72,15 +74,18 @@ class Collector(CollectorBase): class WalkMapper(WalkMapperBase): def map_nabla(self, expr, *args): self.visit(expr, *args) + self.post_visit(expr) def map_nabla_component(self, expr, *args): self.visit(expr, *args) + self.post_visit(expr) def map_derivative_source(self, expr, *args): if not self.visit(expr, *args): return self.rec(expr.operand) + self.post_visit(expr) class EvaluationMapper(EvaluationMapperBase): @@ -118,6 +123,18 @@ class StringifyMapper(StringifyMapperBase): return r"D[%s](%s)" % (expr.nabla_id, self.rec(expr.operand, PREC_NONE)) +class GraphvizMapper(GraphvizMapperBase): + def map_derivative_source(self, expr): + self.lines.append( + "%s [label=\"D[%s]\",shape=ellipse];" % ( + self.get_id(expr), expr.nabla_id)) + if not self.visit(expr, node_printed=True): + return + + self.rec(expr.operand) + self.post_visit(expr) + + # {{{ dimensionalizer class Dimensionalizer(EvaluationMapper): diff --git a/pymbolic/mapper/graphviz.py b/pymbolic/mapper/graphviz.py index f2f2ce471c69d1f83d47449a07ebdd49aa0cffd0..0ee50a53350efbf5d263195a44abd7436b831a62 100644 --- a/pymbolic/mapper/graphviz.py +++ b/pymbolic/mapper/graphviz.py @@ -39,6 +39,8 @@ class GraphvizMapper(WalkMapper): self.parent_stack = [] self.next_unique_id = -1 + self.nodes_visited = set() + self.common_subexpressions = {} def get_dot_code(self): """Return the dot source code for a previously traversed expression.""" @@ -51,25 +53,41 @@ class GraphvizMapper(WalkMapper): return "id%d" % id(expr) + def map_leaf(self, expr): + self.lines.append( + "%s [label=\"%s\", shape=box];" % ( + self.get_id(expr), str(expr))) + + self.visit(expr, node_printed=True) + self.post_visit(expr) + def generate_unique_id(self): self.next_unique_id += 1 return "uid%d" % self.next_unique_id def visit(self, expr, node_printed=False, node_id=None): + # {{{ print connectivity + if node_id is None: node_id = self.get_id(expr) + if self.parent_stack: + self.lines.append("%s -> %s;" % ( + self.get_id(self.parent_stack[-1]), + node_id)) + + # }}} + + if id(expr) in self.nodes_visited: + return False + self.nodes_visited.add(id(expr)) + if not node_printed: self.lines.append( "%s [label=\"%s\"];" % ( self.get_id(expr), type(expr).__name__)) - if self.parent_stack: - self.lines.append("%s -> %s;" % ( - self.get_id(self.parent_stack[-1]), - node_id)) - self.parent_stack.append(expr) return True @@ -102,7 +120,17 @@ class GraphvizMapper(WalkMapper): self.lines.append( "%s [label=\"%s\", shape=box];" % (self.get_id(expr), expr.name)) - self.visit(expr, node_printed=True) + if self.visit(expr, node_printed=True): + self.post_visit(expr) + + def map_lookup(self, expr): + self.lines.append( + "%s [label=\"Lookup[%s]\",shape=box];" % ( + self.get_id(expr), expr.name)) + if not self.visit(expr, node_printed=True): + return + + self.rec(expr.aggregate) self.post_visit(expr) def map_constant(self, expr): @@ -116,5 +144,43 @@ class GraphvizMapper(WalkMapper): "%s [label=\"%s\",shape=ellipse];" % ( node_id, str(expr))) - self.visit(expr, node_printed=True, node_id=node_id) + if not self.visit(expr, node_printed=True, node_id=node_id): + return + + self.post_visit(expr) + + def map_call(self, expr): + from pymbolic.primitives import Variable + if not isinstance(expr.function, Variable): + return super(GraphvizMapper, self).map_call(expr) + + self.lines.append( + "%s [label=\"Call[%s]\",shape=box];" % ( + self.get_id(expr), str(expr.function))) + if not self.visit(expr, node_printed=True): + return + + for child in expr.parameters: + self.rec(child) + self.post_visit(expr) + + def map_common_subexpression(self, expr): + try: + expr = self.common_subexpressions[expr] + except KeyError: + self.common_subexpressions[expr] = expr + + if not self.visit(expr): + return + + self.rec(expr.child) + + self.post_visit(expr) + + # {{{ geometric algebra + + map_nabla_component = map_leaf + map_nabla = map_leaf + + # }}}