diff --git a/sumpy/toys.py b/sumpy/toys.py index 4641d6976318fa3d8472e096d7a79b1d597c8219..08330406e6f00f4fa3a0a60dbb4f37cd8142424c 100644 --- a/sumpy/toys.py +++ b/sumpy/toys.py @@ -120,7 +120,7 @@ class ToyContext(object): # {{{ helpers -def _p2e(psource, center, order, p2e, expn_class): +def _p2e(psource, center, order, p2e, expn_class, expn_kwargs): source_boxes = np.array([0], dtype=np.int32) box_source_starts = np.array([0], dtype=np.int32) box_source_counts_nonchild = np.array( @@ -144,7 +144,8 @@ def _p2e(psource, center, order, p2e, expn_class): #flags="print_hl_cl", out_host=True, **toy_ctx.extra_source_kwargs) - return expn_class(toy_ctx, center, order, coeffs[0]) + return expn_class(toy_ctx, center, order, coeffs[0], derived_from=psource, + **expn_kwargs) def _e2p(psource, targets, e2p): @@ -175,7 +176,7 @@ def _e2p(psource, targets, e2p): return pot -def _e2e(psource, to_center, to_order, e2e, expn_class): +def _e2e(psource, to_center, to_order, e2e, expn_class, expn_kwargs): toy_ctx = psource.toy_ctx target_boxes = np.array([1], dtype=np.int32) @@ -209,7 +210,8 @@ def _e2e(psource, to_center, to_order, e2e, expn_class): #flags="print_hl_cl", out_host=True, **toy_ctx.extra_source_kwargs) - return expn_class(toy_ctx, to_center, to_order, to_coeffs[1]) + return expn_class(toy_ctx, to_center, to_order, to_coeffs[1], + derived_from=psource, **expn_kwargs) # }}} @@ -282,11 +284,12 @@ class PointSources(PotentialSource): ``[ndim, npoints]`` """ - def __init__(self, toy_ctx, points, weights): + def __init__(self, toy_ctx, points, weights, center=None): super(PointSources, self).__init__(toy_ctx) self.points = points self.weights = weights + self._center = center def eval(self, targets): evt, (potential,) = self.toy_ctx.get_p2p()( @@ -295,14 +298,31 @@ class PointSources(PotentialSource): return potential + @property + def center(self): + if self._center is not None: + return self._center + + return np.average(self.points, axis=1) + class ExpansionPotentialSource(PotentialSource): - def __init__(self, toy_ctx, center, order, coeffs): + """ + .. attribute:: radius + + Not used mathematically. Just for visualization, purely advisory. + """ + def __init__(self, toy_ctx, center, order, coeffs, derived_from, + radius=None, expn_style=None): super(ExpansionPotentialSource, self).__init__(toy_ctx) self.center = np.asarray(center) self.order = order self.coeffs = coeffs + self.derived_from = derived_from + self.radius = radius + self.expn_style = expn_style + class MultipoleExpansion(ExpansionPotentialSource): def eval(self, targets): @@ -322,6 +342,16 @@ class PotentialExpressionNode(PotentialSource): self.psources = psources + @property + def center(self): + for psource in self.psources: + try: + return psource.center + except AttributeError: + pass + + raise ValueError("no psource with a center found") + class Sum(PotentialExpressionNode): def eval(self, targets): @@ -343,13 +373,13 @@ class Product(PotentialExpressionNode): # }}} -def multipole_expand(psource, center, order=None): +def multipole_expand(psource, center, order=None, **expn_kwargs): if isinstance(psource, PointSources): if order is None: raise ValueError("order may not be None") return _p2e(psource, center, order, psource.toy_ctx.get_p2m(order), - MultipoleExpansion) + MultipoleExpansion, expn_kwargs) elif isinstance(psource, MultipoleExpansion): if order is None: @@ -357,20 +387,20 @@ def multipole_expand(psource, center, order=None): return _e2e(psource, center, order, psource.toy_ctx.get_m2m(psource.order, order), - MultipoleExpansion) + MultipoleExpansion, expn_kwargs) else: raise TypeError("do not know how to expand '%s'" % type(psource).__name__) -def local_expand(psource, center, order=None): +def local_expand(psource, center, order=None, **expn_kwargs): if isinstance(psource, PointSources): if order is None: raise ValueError("order may not be None") return _p2e(psource, center, order, psource.toy_ctx.get_p2l(order), - LocalExpansion) + LocalExpansion, expn_kwargs) elif isinstance(psource, MultipoleExpansion): if order is None: @@ -378,7 +408,7 @@ def local_expand(psource, center, order=None): return _e2e(psource, center, order, psource.toy_ctx.get_m2l(psource.order, order), - LocalExpansion) + LocalExpansion, expn_kwargs) elif isinstance(psource, LocalExpansion): if order is None: @@ -386,7 +416,7 @@ def local_expand(psource, center, order=None): return _e2e(psource, center, order, psource.toy_ctx.get_l2l(psource.order, order), - LocalExpansion) + LocalExpansion, expn_kwargs) else: raise TypeError("do not know how to expand '%s'" @@ -432,12 +462,9 @@ def l_inf(psource, radius, center=None, npoints=100, debug=False): return np.max(np.abs(z)) -def draw_box(center, radius, **kwargs): - center = np.asarray(center) - - el = center - radius - eh = center + radius +# {{{ schematic visualization +def draw_box(el, eh, **kwargs): import matplotlib.pyplot as pt import matplotlib.patches as mpatches from matplotlib.path import Path @@ -462,4 +489,66 @@ def draw_circle(center, radius, **kwargs): import matplotlib.pyplot as plt plt.gca().add_patch(plt.Circle((center[0], center[1]), radius, **kwargs)) + +def draw_arrow(from_pt, to_pt, shorten=0, **kwargs): + import matplotlib.pyplot as plt + dist = to_pt - from_pt + + from_pt = from_pt + shorten*dist + dist = dist - 2*shorten*dist + plt.arrow(from_pt[0], from_pt[1], dist[0], dist[1], **kwargs) + + +class SchematicVisitor(object): + def __init__(self, default_expn_style="circle"): + self.default_expn_style = default_expn_style + + def rec(self, psource): + getattr(self, "visit_"+type(psource).__name__.lower())(psource) + + def visit_pointsources(self, psource): + import matplotlib.pyplot as plt + plt.plot(psource.points[0], psource.points[1], "o") + + def visit_sum(self, psource): + for ps in psource.psources: + self.rec(ps) + + visit_product = visit_sum + + def visit_multipoleexpansion(self, psource): + expn_style = self.default_expn_style + if psource.expn_style is not None: + expn_style = psource.expn_style + + if psource.radius is not None: + if expn_style == "box": + r2 = psource.radius / np.sqrt(2) + draw_box(psource.center - r2, psource.center + r2, fill=None) + elif expn_style == "circle": + draw_circle(psource.center, psource.radius, fill=None) + else: + raise ValueError("unknown expn_style: %s" % self.expn_style) + + import matplotlib.pyplot as plt + plt.gca().text(psource.center[0], psource.center[1], + type(psource).__name__[0], + verticalalignment='center', horizontalalignment='center') + + if psource.derived_from is not None: + draw_arrow(psource.derived_from.center, psource.center, shorten=0.1, + facecolor="black", length_includes_head=True, width=0.02) + self.rec(psource.derived_from) + + visit_localexpansion = visit_multipoleexpansion + + +def draw_schematic(psource, **kwargs): + SchematicVisitor(**kwargs).rec(psource) + import matplotlib.pyplot as plt + plt.gca().set_aspect("equal") + plt.tight_layout() + +# }}} + # vim: foldmethod=marker