diff --git a/boxtree/pyfmmlib_integration.py b/boxtree/pyfmmlib_integration.py index d40fa8b64048aeb5fae8f32ec24064bd00735f7f..26493795d90c8861cab4675ce759f4ae674d6d67 100644 --- a/boxtree/pyfmmlib_integration.py +++ b/boxtree/pyfmmlib_integration.py @@ -45,8 +45,26 @@ class FMMLibExpansionWrangler(object): # {{{ constructor - def __init__(self, tree, helmholtz_k, nterms, ifgrad=False, - dipole_vec=None, dipoles_already_reordered=False): + def __init__(self, tree, helmholtz_k, fmm_level_to_nterms=None, ifgrad=False, + dipole_vec=None, dipoles_already_reordered=False, nterms=None): + """ + :arg fmm_level_to_nterms: a callable that, upon being passed the tree level + as an integer, returns the value of *nterms* for the multipole and + local expansions on that level. + """ + + if nterms is not None and fmm_level_to_nterms is not None: + raise TypeError("may specify either fmm_level_to_nterms or nterms, " + "but not both") + + if nterms is not None: + from warnings import warn + warn("Passing nterms is deprecated. Pass fmm_level_to_nterms instead.", + DeprecationWarning, stacklevel=2) + + def fmm_level_to_nterms(level): + return nterms + self.tree = tree if helmholtz_k == 0: @@ -56,26 +74,15 @@ class FMMLibExpansionWrangler(object): self.eqn_letter = "h" self.kernel_kwargs = {"zk": helmholtz_k} - self.nterms = nterms + self.level_nterms = np.array([ + fmm_level_to_nterms(lev) for lev in range(tree.nlevels) + ], dtype=np.int32) self.dtype = np.complex128 self.ifgrad = ifgrad self.dim = tree.dimensions - common_extra_kwargs = {} - if self.dim == 3 and self.eqn_letter == "h": - nquad = max(6, int(2.5*nterms)) - from pyfmmlib import legewhts - xnodes, weights = legewhts(nquad, ifwhts=1) - - common_extra_kwargs = { - "xnodes": xnodes, - "wts": weights, - } - - self.common_extra_kwargs = common_extra_kwargs - if dipole_vec is not None: assert dipole_vec.shape == (self.dim, self.tree.nsources) @@ -90,6 +97,22 @@ class FMMLibExpansionWrangler(object): # }}} + @memoize_method + def projection_quad_extra_kwargs(self, level): + common_extra_kwargs = {} + + if self.dim == 3 and self.eqn_letter == "h": + nquad = max(6, int(2.5*self.level_nterms[level])) + from pyfmmlib import legewhts + xnodes, weights = legewhts(nquad, ifwhts=1) + + common_extra_kwargs = { + "xnodes": xnodes, + "wts": weights, + } + + return common_extra_kwargs + # {{{ overridable target lists for the benefit of the QBX FMM def box_target_starts(self): @@ -124,21 +147,28 @@ class FMMLibExpansionWrangler(object): rout = self.get_routine(name, suffix) if self.dim == 2: - return rout + def wrapper(*args, **kwargs): + # not used + kwargs.pop("level_for_projection") + + return rout(*args, **kwargs) else: def wrapper(*args, **kwargs): - kwargs.update(self.common_extra_kwargs) + level_for_projection = kwargs.pop("level_for_projection") + kwargs.update(self.projection_quad_extra_kwargs( + level_for_projection)) + val, ier = rout(*args, **kwargs) if (ier != 0).any(): raise RuntimeError("%s failed with nonzero ier" % name) return val - # Doesn't work in in Py2 - # from functools import update_wrapper - # update_wrapper(wrapper, rout) - return wrapper + # Doesn't work in in Py2 + # from functools import update_wrapper + # update_wrapper(wrapper, rout) + return wrapper def get_direct_eval_routine(self): if self.dim == 2: @@ -224,6 +254,18 @@ class FMMLibExpansionWrangler(object): # {{{ data vector utilities + def expansion_shape(self, nterms): + if self.dim == 2 and self.eqn_letter == "l": + return (nterms+1,) + elif self.dim == 2 and self.eqn_letter == "h": + return (2*nterms+1,) + elif self.dim == 3: + # This is the transpose of the Fortran format, to + # minimize mismatch between C and Fortran orders. + return (2*nterms+1, nterms+1,) + else: + raise ValueError("unsupported dimensionality") + def _expansions_level_starts(self, order_to_size): result = [0] for lev in range(self.tree.nlevels): @@ -231,63 +273,54 @@ class FMMLibExpansionWrangler(object): self.tree.level_start_box_nrs[lev+1] - self.tree.level_start_box_nrs[lev]) - expn_size = order_to_size(self.level_orders[lev]) + expn_size = order_to_size(self.level_nterms[lev]) result.append( result[-1] + expn_size * lev_nboxes) return result - def expansion_shape(self, nterms): - if self.dim == 2 and self.eqn_letter == "l": - return (nterms+1,) - elif self.dim == 2 and self.eqn_letter == "h": - return (2*nterms+1,) - elif self.dim == 3: - # This is the transpose of the Fortran format, to - # minimize mismatch between C and Fortran orders. - return (2*nterms+1, nterms+1,) - else: - raise ValueError("unsupported dimensionality") - - # @memoize_method - # def multipole_expansions_level_starts(self): - # from pytools import product - # return self._expansions_level_starts( - # lambda nterms: product(self.expansion_shape(nterms))) + @memoize_method + def multipole_expansions_level_starts(self): + from pytools import product + return self._expansions_level_starts( + lambda nterms: product(self.expansion_shape(nterms))) - # @memoize_method - # def local_expansions_level_starts(self): - # from pytools import product - # return self._expansions_level_starts( - # lambda nterms: product(self.expansion_shape(nterms))) + @memoize_method + def local_expansions_level_starts(self): + from pytools import product + return self._expansions_level_starts( + lambda nterms: product(self.expansion_shape(nterms))) def multipole_expansions_view(self, mpole_exps, level): box_start, box_stop = self.tree.level_start_box_nrs[level:level+2] - # expn_start, expn_stop = \ - # self.multipole_expansions_level_starts()[level:level+2] - # return (box_start, - # mpole_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1)) - - return box_start, mpole_exps[box_start:box_stop] + expn_start, expn_stop = \ + self.multipole_expansions_level_starts()[level:level+2] + return (box_start, + mpole_exps[expn_start:expn_stop].reshape( + box_stop-box_start, + *self.expansion_shape(self.level_nterms[level]))) def local_expansions_view(self, local_exps, level): box_start, box_stop = self.tree.level_start_box_nrs[level:level+2] - # expn_start, expn_stop = \ - # self.local_expansions_level_starts()[level:level+2] - # return (box_start, - # local_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1)) - - return box_start, local_exps[box_start:box_stop] + expn_start, expn_stop = \ + self.local_expansions_level_starts()[level:level+2] + return (box_start, + local_exps[expn_start:expn_stop].reshape( + box_stop-box_start, + *self.expansion_shape(self.level_nterms[level]))) def multipole_expansion_zeros(self): return np.zeros( - (self.tree.nboxes,) + self.expansion_shape(self.nterms), + self.multipole_expansions_level_starts()[-1], dtype=self.dtype) - local_expansion_zeros = multipole_expansion_zeros + def local_expansion_zeros(self): + return np.zeros( + self.local_expansions_level_starts()[-1], + dtype=self.dtype) def output_zeros(self): if self.ifgrad: @@ -393,7 +426,7 @@ class FMMLibExpansionWrangler(object): rscale=rscale, source=self._get_sources(pslice), center=self.tree.box_centers[:, src_ibox], - nterms=self.nterms, + nterms=self.level_nterms[lev], **kwargs) if ier: @@ -448,7 +481,9 @@ class FMMLibExpansionWrangler(object): rscale2=target_rscale, center2=parent_center, - nterms2=self.nterms, + nterms2=self.level_nterms[target_level], + + level_for_projection=source_level, **kwargs) @@ -548,7 +583,7 @@ class FMMLibExpansionWrangler(object): kwargs["ier"] = ier expn2 = np.zeros( - (ntgt_boxes,) + self.expansion_shape(self.nterms), + (ntgt_boxes,) + self.expansion_shape(self.level_nterms[lev]), dtype=self.dtype) kwargs.update(self.kernel_kwargs) @@ -570,6 +605,9 @@ class FMMLibExpansionWrangler(object): # FIXME: wrong layout, will copy center2=tree.box_centers[:, tgt_ibox_vec], expn2=expn2.T, + + level_for_projection=lev, + **kwargs).T target_local_exps_view[tgt_ibox_vec - target_level_start_ibox] += expn2 @@ -654,7 +692,7 @@ class FMMLibExpansionWrangler(object): rscale=rscale, source=self._get_sources(src_pslice), center=tgt_center, - nterms=self.nterms, + nterms=self.level_nterms[lev], **kwargs) if ier: raise RuntimeError("formta failed") @@ -701,7 +739,9 @@ class FMMLibExpansionWrangler(object): rscale2=target_rscale, center2=tgt_center, - nterms2=self.nterms, + nterms2=self.level_nterms[target_lev], + + level_for_projection=target_lev, **kwargs)[..., 0] diff --git a/test/test_fmm.py b/test/test_fmm.py index f082fcb3a82e040d0365a74f7edd42ebe15c67bc..1e4adf27b014838ab151f374c1e24075951b308d 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -548,10 +548,22 @@ def test_pyfmmlib_fmm(ctx_getter, dims, use_dipoles, helmholtz_k): else: base_nterms = 10 + def fmm_level_to_nterms(lev): + result = base_nterms + + if lev < 3 and helmholtz_k: + # exercise order-varies-by-level capability + result += 5 + + if use_dipoles: + result += 1 + + return result + from boxtree.pyfmmlib_integration import FMMLibExpansionWrangler wrangler = FMMLibExpansionWrangler( trav.tree, helmholtz_k, - nterms=base_nterms + (1 if use_dipoles else 0), + fmm_level_to_nterms=fmm_level_to_nterms, dipole_vec=dipole_vec) from boxtree.fmm import drive_fmm