diff --git a/examples/expansion-toys.py b/examples/expansion-toys.py index 5c0714b978455bd6e205732faf63875c45cda092..72be2c7ec70b4af7014b0a3f268842d17b3a300f 100644 --- a/examples/expansion-toys.py +++ b/examples/expansion-toys.py @@ -21,8 +21,13 @@ def main(): plt.colorbar() plt.show() - lexp = t.multipole_expand(pt_src, [0, 0], 5) + mexp = t.multipole_expand(pt_src, [0, 0], 9) + mexp2 = t.multipole_expand(mexp, [0, 0.25]) + lexp = t.local_expand(mexp, [3, 0]) + lexp2 = t.local_expand(lexp, [3, 1]) + diff = mexp - pt_src + diff = mexp2 - pt_src diff = lexp - pt_src if 1: diff --git a/sumpy/toys.py b/sumpy/toys.py index 6b0499734fb3b61089ac7f84591aab5e505d6d55..f61a90903e963f581d26d2dd2f55b68b2ac710b8 100644 --- a/sumpy/toys.py +++ b/sumpy/toys.py @@ -94,6 +94,27 @@ class ToyContext(object): self.local_expn_class(self.kernel, order), [self.kernel]) + @memoize_method + def get_m2m(self, from_order, to_order): + from sumpy import E2EFromCSR + return E2EFromCSR(self.cl_context, + self.mpole_expn_class(self.kernel, from_order), + self.mpole_expn_class(self.kernel, to_order)) + + @memoize_method + def get_m2l(self, from_order, to_order): + from sumpy import E2EFromCSR + return E2EFromCSR(self.cl_context, + self.mpole_expn_class(self.kernel, from_order), + self.local_expn_class(self.kernel, to_order)) + + @memoize_method + def get_l2l(self, from_order, to_order): + from sumpy import E2EFromCSR + return E2EFromCSR(self.cl_context, + self.local_expn_class(self.kernel, from_order), + self.local_expn_class(self.kernel, to_order)) + # }}} @@ -153,9 +174,48 @@ def _e2p(psource, targets, e2p): return pot + +def _e2e(psource, to_center, to_order, e2e, expn_class): + toy_ctx = psource.toy_ctx + + target_boxes = np.array([1], dtype=np.int32) + src_box_starts = np.array([0, 1], dtype=np.int32) + src_box_lists = np.array([0], dtype=np.int32) + + centers = (np.array( + [ + # box 0: source + psource.center, + + # box 1: target + to_center, + ], + dtype=np.float64)).T.copy() + + coeffs = np.array([psource.coeffs]) + + evt, (to_coeffs,) = e2e( + toy_ctx.queue, + src_expansions=coeffs, + src_base_ibox=0, + tgt_base_ibox=0, + ntgt_level_boxes=2, + + target_boxes=target_boxes, + + src_box_starts=src_box_starts, + src_box_lists=src_box_lists, + centers=centers, + #flags="print_hl_cl", + out_host=True, **toy_ctx.extra_source_kwargs) + + return expn_class(toy_ctx, to_center, to_order, to_coeffs[1]) + # }}} +# {{{ potential source classes + class PotentialSource(object): def __init__(self, toy_ctx): self.toy_ctx = toy_ctx @@ -254,32 +314,6 @@ class LocalExpansion(ExpansionPotentialSource): return _e2p(self, targets, self.toy_ctx.get_l2p(self.order)) -def multipole_expand(psource, center, order): - if isinstance(psource, PointSources): - return _p2e(psource, center, order, psource.toy_ctx.get_p2m(order), - MultipoleExpansion) - - elif isinstance(psource, MultipoleExpansion): - raise NotImplementedError() - else: - raise TypeError("do not know how to expand '%s'" - % type(psource).__name__) - - -def local_expand(psource, center, order): - if isinstance(psource, PointSources): - return _p2e(psource, center, order, psource.toy_ctx.get_p2l(order), - LocalExpansion) - - elif isinstance(psource, MultipoleExpansion): - raise NotImplementedError() - elif isinstance(psource, LocalExpansion): - raise NotImplementedError() - else: - raise TypeError("do not know how to expand '%s'" - % type(psource).__name__) - - class PotentialExpressionNode(PotentialSource): def __init__(self, psources): from pytools import single_valued @@ -306,6 +340,58 @@ class Product(PotentialExpressionNode): return result +# }}} + + +def multipole_expand(psource, center, order=None): + if isinstance(psource, PointSources): + if order is None: + raise ValueError("order may not be None") + + return _p2e(psource, center, order, psource.toy_ctx.get_p2m(order), + MultipoleExpansion) + + elif isinstance(psource, MultipoleExpansion): + if order is None: + order = psource.order + + return _e2e(psource, center, order, + psource.toy_ctx.get_m2m(psource.order, order), + MultipoleExpansion) + + else: + raise TypeError("do not know how to expand '%s'" + % type(psource).__name__) + + +def local_expand(psource, center, order=None): + if isinstance(psource, PointSources): + if order is None: + raise ValueError("order may not be None") + + return _p2e(psource, center, order, psource.toy_ctx.get_p2l(order), + LocalExpansion) + + elif isinstance(psource, MultipoleExpansion): + if order is None: + order = psource.order + + return _e2e(psource, center, order, + psource.toy_ctx.get_m2l(psource.order, order), + LocalExpansion) + + elif isinstance(psource, LocalExpansion): + if order is None: + order = psource.order + + return _e2e(psource, center, order, + psource.toy_ctx.get_l2l(psource.order, order), + LocalExpansion) + + else: + raise TypeError("do not know how to expand '%s'" + % type(psource).__name__) + def logplot(fp, psource, **kwargs): fp.show_scalar_in_matplotlib(