import loopy as lp
import numpy as np
from sumpy.qbx import LayerPotentialBase


# {{{ qbx applier on a target/center subset

class LayerPotentialOnTargetAndCenterSubset(LayerPotentialBase):
    """Layer potential applier working on a subset of target/center pairs.

    The kernel iterates over ``ntargets`` subset entries. For entry *itgt*,
    ``qbx_tgt_numbers[itgt]`` selects the target (an index into arrays of
    length ``ntargets_total``) and ``qbx_center_numbers[itgt]`` selects the
    center (an index into arrays of length ``ncenters_total``). Results are
    written only at the selected target indices.
    """

    default_name = "qbx_tgt_ctr_subset"

    def get_kernel(self):
        """Build and return the loopy kernel for the subset evaluation.

        Kernel arguments, in addition to those returned by
        ``gather_loopy_source_arguments(self.kernels)``:

        * ``src``: ``(dim, nsources)``
        * ``tgt``: ``(dim, ntargets_total)``
        * ``center``: ``(dim, ncenters_total)``, declared with
          ``dim_tags="sep,C"``
        * ``expansion_radii``: ``(ncenters_total,)``
        * ``qbx_tgt_numbers``, ``qbx_center_numbers``: ``(ntargets,)``
        * ``strength_<i>``: ``(nsources,)``, one per ``self.strength_count``
        * ``result_<i>``: ``(ntargets_total,)`` with dtype
          ``self.value_dtypes[i]``, one per entry of ``self.kernels``

        :returns: the loopy kernel, with all ``idim*`` inames tagged for
            unrolling and after each expansion's ``prepare_loopy_kernel``
            has been applied.
        """
        loopy_insns, result_names = self.get_loopy_insns_and_result_names()
        kernel_exprs = self.get_kernel_exprs(result_names)

        from sumpy.tools import gather_loopy_source_arguments
        arguments = (
            gather_loopy_source_arguments(self.kernels)
            + [
                lp.GlobalArg("src", None,
                    shape=(self.dim, "nsources"), order="C"),
                lp.GlobalArg("tgt", None,
                    shape=(self.dim, "ntargets_total"), order="C"),
                lp.GlobalArg("center", None,
                    shape=(self.dim, "ncenters_total"), dim_tags="sep,C"),
                lp.GlobalArg("expansion_radii", None,
                    shape="ncenters_total"),
                lp.GlobalArg("qbx_tgt_numbers", None,
                    shape="ntargets"),
                lp.GlobalArg("qbx_center_numbers", None,
                    shape="ntargets"),
                lp.ValueArg("nsources", np.int32),
                lp.ValueArg("ntargets", np.int32),
                lp.ValueArg("ntargets_total", np.int32),
                lp.ValueArg("ncenters_total", np.int32)]
            + [lp.GlobalArg("strength_%d" % i, None,
                shape="nsources", order="C")
                for i in range(self.strength_count)]
            + [lp.GlobalArg("result_%d" % i, self.value_dtypes[i],
                shape="ntargets_total", order="C")
                for i in range(len(self.kernels))])

        # Instruction stream: scaling assignments, then inside the
        # (itgt, isrc) loop the index lookups and the a/b vectors
        # (a = center - source, b = target - center), then the per-pair
        # expressions, and finally one reduction over isrc per expansion.
        # Trailing backslashes inside the instruction strings join the
        # continued lines, so each instruction reaches loopy as one line.
        loopy_knl = lp.make_kernel([
            "{[itgt]: 0 <= itgt < ntargets}",
            "{[isrc]: 0 <= isrc < nsources}",
            "{[idim]: 0 <= idim < dim}"
            ],
            self.get_kernel_scaling_assignments()
            + ["for itgt, isrc"]
            + ["""
                <> icenter = qbx_center_numbers[itgt]
                <> itgt_overall = qbx_tgt_numbers[itgt]

                <> a[idim] = center[idim, icenter] - src[idim, isrc] \
                        {dup=idim}
                <> b[idim] = tgt[idim, itgt_overall] - center[idim, icenter] \
                        {dup=idim}
                <> rscale = expansion_radii[icenter]
            """]
            + loopy_insns + kernel_exprs
            + ["""
                result_{i}[itgt_overall] = knl_{i}_scaling * \
                    simul_reduce(sum, isrc, pair_result_{i}) \
                        {{inames=itgt}}
            """.format(i=iknl)
            for iknl in range(len(self.expansions))]
            + ["end"],
            arguments,
            name=self.name,
            assumptions="ntargets>=1 and nsources>=1",
            fixed_parameters=dict(dim=self.dim))

        loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
        for expn in self.expansions:
            loopy_knl = expn.prepare_loopy_kernel(loopy_knl)

        return loopy_knl

    def __call__(self, queue, targets, sources, centers, strengths,
            expansion_radii, **kwargs):
        """Run the cached, optimized kernel on *queue*.

        :arg targets: passed to the kernel as ``tgt``.
        :arg sources: passed to the kernel as ``src``.
        :arg centers: passed to the kernel as ``center``.
        :arg strengths: an iterable; entry *i* is passed to the kernel as
            ``strength_<i>``.
        :arg expansion_radii: passed to the kernel as ``expansion_radii``.
        :arg kwargs: forwarded to the kernel unchanged. The kernel also
            declares ``qbx_tgt_numbers`` and ``qbx_center_numbers``, which
            this method does not bind itself, so callers presumably supply
            them here (verify against callers).
        :returns: whatever invoking the kernel returns.
        """
        knl = self.get_cached_optimized_kernel()

        for i, dens in enumerate(strengths):
            kwargs["strength_%d" % i] = dens

        return knl(queue, src=sources, tgt=targets, center=centers,
                expansion_radii=expansion_radii, **kwargs)

# }}}

# vim: foldmethod=marker