diff --git a/sumpy/p2p.py b/sumpy/p2p.py index 33bc4d70b07aac1371234c5b06a82c854d11afdc..9dd8163a7f3501dfdd4f4e271da7ec0d9897301f 100644 --- a/sumpy/p2p.py +++ b/sumpy/p2p.py @@ -459,19 +459,14 @@ class P2PFromCSR(P2PBase): shape="nstrengths, nsources", dim_tags="sep,C"), lp.GlobalArg("result", None, shape="noutputs, ntargets", dim_tags="sep,C"), - lp.TemporaryVariable("local_isrc_strength", - shape="nstrengths, max_npoints_in_one_box", - address_space=lp.AddressSpace.LOCAL), - lp.TemporaryVariable("local_isrc", - shape=(self.dim, max_npoints_in_one_box), - address_space=lp.AddressSpace.LOCAL), lp.TemporaryVariable("tgt_center", shape=(self.dim,)), "..." ]) loopy_knl = lp.make_kernel([ "{[itgt_box]: 0 <= itgt_box < ntgt_boxes}", - "{[ipoint]: 0 <= ipoint < max_npoints_in_one_box}", + "{[itgt_rel]: 0 <= itgt_rel < max_npoints_in_one_box}", + "{[isrc_rel]: 0 <= isrc_rel < max_npoints_in_one_box}", "{[iknl]: 0 <= iknl < noutputs}", "{[isrc_box]: isrc_box_start <= isrc_box < isrc_box_end}", "{[idim]: 0 <= idim < dim}", @@ -487,8 +482,8 @@ class P2PFromCSR(P2PBase): <> isrc_box_start = source_box_starts[itgt_box] <> isrc_box_end = source_box_starts[itgt_box+1] - for ipoint - <> itgt = ipoint + itgt_start + for itgt_rel + <> itgt = itgt_rel + itgt_start <> cond_itgt = itgt < itgt_end <> acc[iknl] = 0 {id=init_acc, dup=iknl} if cond_itgt @@ -498,12 +493,14 @@ class P2PFromCSR(P2PBase): <> src_ibox = source_box_lists[isrc_box] <> isrc_start = box_source_starts[src_ibox] <> isrc_end = isrc_start + box_source_counts_nonchild[src_ibox] - <> cond_isrc = ipoint < isrc_end - isrc_start - if cond_isrc - local_isrc[idim, ipoint] = sources[idim, - ipoint + isrc_start] {id=load_src, dup=idim} - local_isrc_strength[istrength, ipoint] = strength[istrength, - ipoint + isrc_start] {id=load_charge} + for isrc_rel + <> cond_isrc = isrc_rel < isrc_end - isrc_start + if cond_isrc + <> local_isrc[idim, isrc_rel] = sources[idim, + isrc_rel + isrc_start] {id=load_src, dup=idim} + <> local_isrc_strength[istrength, isrc_rel] = strength[ + istrength, isrc_rel + isrc_start] {id=load_charge} + end end if cond_itgt for isrc @@ -553,6 +550,7 @@ class P2PFromCSR(P2PBase): loopy_knl = lp.tag_inames(loopy_knl, "istrength*:unr") loopy_knl = lp.tag_array_axes(loopy_knl, "targets", "sep,C") loopy_knl = lp.tag_array_axes(loopy_knl, "sources", "sep,C") + for knl in self.target_kernels + self.source_kernels: loopy_knl = knl.prepare_loopy_kernel(loopy_knl) @@ -568,9 +566,12 @@ class P2PFromCSR(P2PBase): knl = lp.split_iname(knl, "itgt_box", 4, outer_tag="g.0") else: knl = lp.tag_inames(knl, {"itgt_box": "g.0"}) - knl = lp.split_iname(knl, "ipoint", max_npoints_in_one_box, + knl = lp.rename_inames(knl, ["isrc_rel"], "itgt_rel", existing_ok=True) + knl = lp.split_iname(knl, "itgt_rel", max_npoints_in_one_box, inner_tag="l.0") - knl = lp.add_inames_for_unused_hw_axes(knl) + knl = lp.set_temporary_address_space(knl, + ["local_isrc", "local_isrc_strength"], lp.AddressSpace.LOCAL) + knl = lp.add_inames_for_unused_hw_axes(knl) knl = self._allow_redundant_execution_of_knl_scaling(knl) knl = lp.set_options(knl,