diff --git a/sumpy/toys.py b/sumpy/toys.py index 4f11fcbf487a39654fab3d047fe27f4f53484477..20ba4e0ed0ec5dc5e839051c6cf55267b7d900ae 100644 --- a/sumpy/toys.py +++ b/sumpy/toys.py @@ -338,8 +338,9 @@ class ExpansionPotentialSource(PotentialSource): .. attribute:: text_kwargs - Passed to :method:`matplotlib.axes.Axes.text`. Used for customizing the - expansion label. Just for visualization, purely advisory. + Passed to :method:`matplotlib.pyplot.annotate`. Used for customizing the + expansion label. Changing the label text is supported by passing the + kwarg *s*. Just for visualization, purely advisory. """ def __init__(self, toy_ctx, center, rscale, order, coeffs, derived_from, radius=None, expn_style=None, text_kwargs=None): @@ -526,28 +527,26 @@ def draw_point(loc, **kwargs): plt.plot(*loc, marker="o", **kwargs) -def draw_arrow(from_pt, to_pt, shorten=0, **kwargs): - import matplotlib.pyplot as plt - dist = to_pt - from_pt - - from_pt = from_pt + shorten*dist - dist = dist - 2*shorten*dist - plt.arrow(from_pt[0], from_pt[1], dist[0], dist[1], **kwargs) - +def draw_annotation(from_pt, to_pt, label, arrowprops={}, **kwargs): + """ + :arg from_pt: Tail of arrow + :arg to_pt: Head of arrow + :arg label: Annotation label + :arg arrowprops: Passed to arrowprops + :arg kwargs: Passed to annotate + """ -def draw_annotation(from_pt, to_pt, label, **kwargs): import matplotlib.pyplot as plt - color = kwargs.setdefault("color", "white") - arrowprops = dict( - facecolor=color, - edgecolor=color, - shrink=0.05, - width=1, - headwidth=5) + my_arrowprops = dict( + facecolor="black", + edgecolor="black", + arrowstyle="->") + + my_arrowprops.update(arrowprops) plt.gca().annotate(label, xy=from_pt, xytext=to_pt, - arrowprops=arrowprops, **kwargs) + arrowprops=my_arrowprops, **kwargs) class SchematicVisitor(object): @@ -581,26 +580,37 @@ class SchematicVisitor(object): else: raise ValueError("unknown expn_style: %s" % self.expn_style) + if psource.derived_from is None: + return + + # Draw an annotation of the form + # + # ------> M + text_kwargs = dict( - x=psource.center[0], - y=psource.center[1], - s=type(psource).__name__[0], verticalalignment="center", horizontalalignment="center") - if psource.text_kwargs is not None: - text_kwargs.update(psource.text_kwargs) + label = type(psource).__name__[0] - import matplotlib.pyplot as plt - plt.gca().text(**text_kwargs) - - if psource.derived_from is not None: - xmin, xmax = plt.xlim() - plt_width = xmax - xmin - draw_arrow(psource.derived_from.center, psource.center, shorten=0.1, - facecolor="black", length_includes_head=True, - width=0.0005 * plt_width) - self.rec(psource.derived_from) + if psource.text_kwargs is not None: + psource_text_kwargs_copy = psource.text_kwargs.copy() + label = psource_text_kwargs_copy.pop('s', label) + text_kwargs.update(psource_text_kwargs_copy) + + shrinkA = 0 # noqa + if isinstance(psource.derived_from, ExpansionPotentialSource): + # Avoid overlapping the tail of the arrow with any expansion labels that + # are present at the tail. + import matplotlib as mpl + font_size = mpl.rcParams['font.size'] + shrinkA = 2/3 * font_size # noqa + + arrowprops = dict(shrinkA=shrinkA, arrowstyle="<|-") # noqa + + draw_annotation(psource.center, psource.derived_from.center, label, + arrowprops, **text_kwargs) + self.rec(psource.derived_from) visit_localexpansion = visit_multipoleexpansion