diff --git a/sumpy/p2p.py b/sumpy/p2p.py index aa9da6eea0d838d43a73c715ac5c4b23b95ea837..160a2c28e6a2b95989846a21396c47911366b386 100644 --- a/sumpy/p2p.py +++ b/sumpy/p2p.py @@ -440,8 +440,7 @@ class P2PFromCSR(P2PBase): def get_kernel(self, max_npoints_in_one_box, gpu=False, nsplit=32): loopy_insns, result_names = self.get_loopy_insns_and_result_names() - arguments = ( - self.get_default_src_tgt_arguments() + arguments = self.get_default_src_tgt_arguments() \ + [ lp.GlobalArg("box_target_starts", None, shape=None), @@ -460,12 +459,8 @@ class P2PFromCSR(P2PBase): lp.GlobalArg("result", None, shape="noutputs, ntargets", dim_tags="sep,C"), lp.TemporaryVariable("tgt_center", shape=(self.dim,)), - lp.TemporaryVariable("local_isrc", - shape=(self.dim, max_npoints_in_one_box)), - lp.TemporaryVariable("local_isrc_strength", - shape=(self.strength_count, max_npoints_in_one_box)), "..." - ]) + ] domains = [ "{[itgt_box]: 0 <= itgt_box < ntgt_boxes}", @@ -479,6 +474,12 @@ class P2PFromCSR(P2PBase): outer_limit = (max_npoints_in_one_box - 1) // nsplit if gpu: + arguments += [ + lp.TemporaryVariable("local_isrc", + shape=(self.dim, max_npoints_in_one_box)), + lp.TemporaryVariable("local_isrc_strength", + shape=(self.strength_count, max_npoints_in_one_box)), + ] domains += [ "{[inner]: 0 <= inner < nsplit}", "{[itgt_offset_outer]: 0 <= itgt_offset_outer <= outer_limit}",