diff --git a/sumpy/toys.py b/sumpy/toys.py index d61af4faa80b1e40a7015f3b0c4db9ed4b00e231..9262782ab2695ddef4c812d73c7deb695a700cdd 100644 --- a/sumpy/toys.py +++ b/sumpy/toys.py @@ -494,6 +494,11 @@ def draw_circle(center, radius, **kwargs): plt.gca().add_patch(plt.Circle((center[0], center[1]), radius, **kwargs)) +def draw_point(loc, **kwargs): + import matplotlib.pyplot as plt + plt.plot(*loc, marker="o", **kwargs) + + def draw_arrow(from_pt, to_pt, shorten=0, **kwargs): import matplotlib.pyplot as plt dist = to_pt - from_pt @@ -503,6 +508,15 @@ def draw_arrow(from_pt, to_pt, shorten=0, **kwargs): plt.arrow(from_pt[0], from_pt[1], dist[0], dist[1], **kwargs) +def draw_annotation(from_pt, to_pt, label, **kwargs): + import matplotlib.pyplot as plt + + plt.gca().annotate(label, xy=from_pt, xytext=to_pt, + arrowprops=dict( + facecolor="white", edgecolor="white", shrink=0.05, + width=1, headwidth=5), color="white", **kwargs) + + class SchematicVisitor(object): def __init__(self, default_expn_style="circle"): self.default_expn_style = default_expn_style @@ -512,7 +526,7 @@ class SchematicVisitor(object): def visit_pointsources(self, psource): import matplotlib.pyplot as plt - plt.plot(psource.points[0], psource.points[1], "o") + plt.plot(psource.points[0], psource.points[1], "o", label="source") def visit_sum(self, psource): for ps in psource.psources: @@ -540,8 +554,11 @@ class SchematicVisitor(object): verticalalignment='center', horizontalalignment='center') if psource.derived_from is not None: + xmin, xmax = plt.xlim() + plt_width = xmax - xmin draw_arrow(psource.derived_from.center, psource.center, shorten=0.1, - facecolor="black", length_includes_head=True, width=0.02) + facecolor="black", length_includes_head=True, + width=0.0005 * plt_width) self.rec(psource.derived_from) visit_localexpansion = visit_multipoleexpansion