diff --git a/boxtree/fmm.py b/boxtree/fmm.py index 2d18a7baf4ebc982b9973200c5d4a356a48d97a5..69fc0f2e3255782e6c121734bc703d4b88d72823 100644 --- a/boxtree/fmm.py +++ b/boxtree/fmm.py @@ -210,7 +210,7 @@ class ExpansionWranglerInterface: box in the tree. """ - def potential_zeros(self): + def output_zeros(self): """Return a potentials array (which must support addition) capable of holding a potential value for each target in the tree. Note that :func:`drive_fmm` makes no assumptions about *potential* other than @@ -254,7 +254,7 @@ class ExpansionWranglerInterface: neighbor sources due to *src_weights*, which use :ref:`csr` and are indexed like *target_boxes*. - :returns: a new potential array, see :meth:`potential_zeros`. + :returns: a new potential array, see :meth:`output_zeros`. """ def multipole_to_local(self, @@ -277,7 +277,7 @@ class ExpansionWranglerInterface: return a new potential array. *starts* and *lists* use :ref:`csr` and *starts* is indexed like *target_boxes*. - :returns: a new potential array, see :meth:`potential_zeros`. + :returns: a new potential array, see :meth:`output_zeros`. """ def form_locals(self, @@ -307,7 +307,7 @@ class ExpansionWranglerInterface: """For each box in *target_boxes*, evaluate the local expansion in *local_exps* and return a new potential array. - :returns: a new potential array, see :meth:`potential_zeros`. + :returns: a new potential array, see :meth:`output_zeros`. """ # }}} diff --git a/boxtree/pyfmmlib_integration.py b/boxtree/pyfmmlib_integration.py index 4fa7f457156614abe73f952bcb5aad83134c328d..78f08d4d7f612ea130e6f3a1e96cb01cc357c89e 100644 --- a/boxtree/pyfmmlib_integration.py +++ b/boxtree/pyfmmlib_integration.py @@ -38,12 +38,14 @@ class HelmholtzExpansionWrangler: by using pyfmmlib. """ - def __init__(self, tree, helmholtz_k, nterms): + def __init__(self, tree, helmholtz_k, nterms, ifgrad=False): self.tree = tree self.helmholtz_k = helmholtz_k self.nterms = nterms self.dtype = np.complex128 + self.ifgrad = ifgrad + self.dim = tree.dimensions common_extra_kwargs = {} @@ -113,10 +115,14 @@ class HelmholtzExpansionWrangler: rout = self.get_routine("potgrad%ddall", "_vec") def wrapper(*args, **kwargs): - kwargs["ifgrad"] = False + kwargs["ifgrad"] = self.ifgrad kwargs["ifhess"] = False pot, grad, hess = rout(*args, **kwargs) - return pot + + if not self.ifgrad: + grad = 0 + + return pot, grad # Doesn't work in in Py2 # from functools import update_wrapper @@ -127,10 +133,13 @@ class HelmholtzExpansionWrangler: rout = self.get_routine("potfld%ddall", "_vec") def wrapper(*args, **kwargs): - kwargs["iffld"] = False + kwargs["iffld"] = self.ifgrad pot, fld = rout(*args, **kwargs) - # grad = -fld - return pot + if self.ifgrad: + grad = -fld + else: + grad = 0 + return pot, grad # Doesn't work in in Py2 # from functools import update_wrapper @@ -145,10 +154,14 @@ class HelmholtzExpansionWrangler: if self.dim == 2: def wrapper(*args, **kwargs): - kwargs["ifgrad"] = False + kwargs["ifgrad"] = self.ifgrad kwargs["ifhess"] = False + pot, grad, hess = rout(*args, **kwargs) - return pot + if not self.ifgrad: + grad = 0 + + return pot, grad # Doesn't work in in Py2 # from functools import update_wrapper @@ -157,12 +170,18 @@ class HelmholtzExpansionWrangler: elif self.dim == 3: def wrapper(*args, **kwargs): - kwargs["iffld"] = False + kwargs["iffld"] = self.ifgrad pot, fld, ier = rout(*args, **kwargs) - # grad = -fld + if (ier != 0).any(): raise RuntimeError("%s failed with nonzero ier" % name) - return pot + + if self.ifgrad: + grad = -fld + else: + grad = 0 + + return pot, grad # Doesn't work in in Py2 # from functools import update_wrapper @@ -236,8 +255,21 @@ class HelmholtzExpansionWrangler: local_expansion_zeros = multipole_expansion_zeros - def potential_zeros(self): - return np.zeros(self.tree.ntargets, dtype=self.dtype) + def output_zeros(self): + if self.ifgrad: + from pytools import make_obj_array + return make_obj_array([ + np.zeros(self.tree.ntargets, self.dtype) + for i in range(1 + self.dim)]) + else: + return np.zeros(self.tree.ntargets, self.dtype) + + def add_potgrad_onto_output(self, output, output_slice, pot, grad): + if self.ifgrad: + output[0, output_slice] += pot + output[1:, output_slice] += grad + else: + output[output_slice] += pot # }}} @@ -327,7 +359,7 @@ class HelmholtzExpansionWrangler: def eval_direct(self, target_boxes, neighbor_sources_starts, neighbor_sources_lists, src_weights): - pot = self.potential_zeros() + output = self.output_zeros() ev = self.get_direct_eval_routine() @@ -337,7 +369,10 @@ class HelmholtzExpansionWrangler: if tgt_pslice.stop - tgt_pslice.start == 0: continue - tgt_result = np.zeros(tgt_pslice.stop - tgt_pslice.start, self.dtype) + #tgt_result = np.zeros(tgt_pslice.stop - tgt_pslice.start, self.dtype) + tgt_pot_result = 0 + tgt_grad_result = 0 + start, end = neighbor_sources_starts[itgt_box:itgt_box+2] for src_ibox in neighbor_sources_lists[start:end]: src_pslice = self._get_source_slice(src_ibox) @@ -345,16 +380,18 @@ class HelmholtzExpansionWrangler: if src_pslice.stop - src_pslice.start == 0: continue - tmp_pot = ev( + tmp_pot, tmp_grad = ev( sources=self._get_sources(src_pslice), charge=src_weights[src_pslice], targets=self._get_targets(tgt_pslice), zk=self.helmholtz_k) - tgt_result += tmp_pot + tgt_pot_result += tmp_pot + tgt_grad_result += tmp_grad - pot[tgt_pslice] = tgt_result + self.add_potgrad_onto_output( + output, tgt_pslice, tgt_pot_result, tgt_grad_result) - return pot + return output def multipole_to_local(self, level_start_target_or_target_parent_box_nrs, @@ -398,7 +435,7 @@ class HelmholtzExpansionWrangler: def eval_multipoles(self, level_start_target_box_nrs, target_boxes, sep_smaller_nonsiblings_by_level, mpole_exps): - pot = self.potential_zeros() + output = self.output_zeros() rscale = 1 @@ -412,18 +449,21 @@ class HelmholtzExpansionWrangler: continue tgt_pot = 0 + tgt_grad = 0 start, end = ssn.starts[itgt_box:itgt_box+2] for src_ibox in ssn.lists[start:end]: - tmp_pot = mpeval(self.helmholtz_k, rscale, self. + tmp_pot, tmp_grad = mpeval(self.helmholtz_k, rscale, self. tree.box_centers[:, src_ibox], mpole_exps[src_ibox], self._get_targets(tgt_pslice)) tgt_pot = tgt_pot + tmp_pot + tgt_grad = tgt_grad + tmp_grad - pot[tgt_pslice] += tgt_pot + self.add_potgrad_onto_output( + output, tgt_pslice, tgt_pot, tgt_grad) - return pot + return output def form_locals(self, level_start_target_or_target_parent_box_nrs, @@ -487,7 +527,7 @@ class HelmholtzExpansionWrangler: return local_exps def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): - pot = self.potential_zeros() + output = self.output_zeros() rscale = 1 # FIXME taeval = self.get_expn_eval_routine("ta") @@ -498,10 +538,11 @@ class HelmholtzExpansionWrangler: if tgt_pslice.stop - tgt_pslice.start == 0: continue - tmp_pot = taeval(self.helmholtz_k, rscale, + tmp_pot, tmp_grad = taeval(self.helmholtz_k, rscale, self.tree.box_centers[:, tgt_ibox], local_exps[tgt_ibox], self._get_targets(tgt_pslice)) - pot[tgt_pslice] += tmp_pot + self.add_potgrad_onto_output( + output, tgt_pslice, tmp_pot, tmp_grad) - return pot + return output