From e7b455aa40dbc6a3965dca4add1f351c21b90585 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Thu, 3 Aug 2023 19:28:20 +0300
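Subject: [PATCH] toys: only send cl arrays to kernels

Convert the host-side numpy inputs of the toy helpers (_p2e, _e2p, _e2e
and PointSources.eval) to device arrays with cl.array.to_device before
calling the kernels, and copy the results back with .get(queue).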

---
 sumpy/toys.py | 96 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 59 insertions(+), 37 deletions(-)

diff --git a/sumpy/toys.py b/sumpy/toys.py
index 9fa67cbc..931e96e5 100644
--- a/sumpy/toys.py
+++ b/sumpy/toys.py
@@ -201,48 +201,59 @@ class ToyContext:
 # {{{ helpers
 
 def _p2e(psource, center, rscale, order, p2e, expn_class, expn_kwargs):
-    source_boxes = np.array([0], dtype=np.int32)
-    box_source_starts = np.array([0], dtype=np.int32)
-    box_source_counts_nonchild = np.array(
-            [psource.points.shape[-1]], dtype=np.int32)
-
     toy_ctx = psource.toy_ctx
+    queue = toy_ctx.queue
+
+    source_boxes = cl.array.to_device(
+        queue, np.array([0], dtype=np.int32))
+    box_source_starts = cl.array.to_device(
+        queue, np.array([0], dtype=np.int32))
+    box_source_counts_nonchild = cl.array.to_device(
+        queue, np.array([psource.points.shape[-1]], dtype=np.int32))
+
     center = np.asarray(center)
-    centers = np.array(center, dtype=np.float64).reshape(
-            toy_ctx.kernel.dim, 1)
+    centers = cl.array.to_device(
+        queue,
+        np.array(center, dtype=np.float64).reshape(toy_ctx.kernel.dim, 1))
 
     evt, (coeffs,) = p2e(toy_ctx.queue,
             source_boxes=source_boxes,
             box_source_starts=box_source_starts,
             box_source_counts_nonchild=box_source_counts_nonchild,
             centers=centers,
-            sources=psource.points,
-            strengths=(psource.weights,),
+            sources=cl.array.to_device(queue, psource.points),
+            strengths=(cl.array.to_device(queue, psource.weights),),
             rscale=rscale,
             nboxes=1,
             tgt_base_ibox=0,
 
-            #flags="print_hl_cl",
             out_host=True,
             **toy_ctx.extra_source_and_kernel_kwargs)
 
-    return expn_class(toy_ctx, center, rscale, order, coeffs[0],
+    return expn_class(toy_ctx, center, rscale, order, coeffs[0].get(queue),
             derived_from=psource, **expn_kwargs)
 
 
 def _e2p(psource, targets, e2p):
-    ntargets = targets.shape[-1]
+    toy_ctx = psource.toy_ctx
+    queue = toy_ctx.queue
 
-    boxes = np.array([0], dtype=np.int32)
+    ntargets = targets.shape[-1]
+    boxes = cl.array.to_device(
+        queue, np.array([0], dtype=np.int32))
+    box_target_starts = cl.array.to_device(
+        queue, np.array([0], dtype=np.int32))
+    box_target_counts_nonchild = cl.array.to_device(
+        queue, np.array([ntargets], dtype=np.int32))
 
-    box_target_starts = np.array([0], dtype=np.int32)
-    box_target_counts_nonchild = np.array([ntargets], dtype=np.int32)
+    centers = cl.array.to_device(
+        queue,
+        np.array(psource.center, dtype=np.float64).reshape(toy_ctx.kernel.dim, 1))
 
-    toy_ctx = psource.toy_ctx
-    centers = np.array(psource.center, dtype=np.float64).reshape(
-            toy_ctx.kernel.dim, 1)
+    from pytools.obj_array import make_obj_array
+    from sumpy.tools import vector_to_device
 
-    coeffs = np.array([psource.coeffs])
+    coeffs = cl.array.to_device(queue, np.array([psource.coeffs]))
     evt, (pot,) = e2p(
             toy_ctx.queue,
             src_expansions=coeffs,
@@ -252,31 +263,38 @@ def _e2p(psource, targets, e2p):
             box_target_counts_nonchild=box_target_counts_nonchild,
             centers=centers,
             rscale=psource.rscale,
-            targets=targets,
-            #flags="print_hl_cl",
-            out_host=True, **toy_ctx.extra_kernel_kwargs)
+            targets=vector_to_device(queue, make_obj_array(targets)),
+
+            out_host=True,
+            **toy_ctx.extra_kernel_kwargs)
 
-    return pot
+    return pot.get(queue)
 
 
 def _e2e(psource, to_center, to_rscale, to_order, e2e, expn_class, expn_kwargs):
     toy_ctx = psource.toy_ctx
-
-    target_boxes = np.array([1], dtype=np.int32)
-    src_box_starts = np.array([0, 1], dtype=np.int32)
-    src_box_lists = np.array([0], dtype=np.int32)
-
-    centers = (np.array(
+    queue = toy_ctx.queue
+
+    target_boxes = cl.array.to_device(
+        queue, np.array([1], dtype=np.int32))
+    src_box_starts = cl.array.to_device(
+        queue, np.array([0, 1], dtype=np.int32))
+    src_box_lists = cl.array.to_device(
+        queue, np.array([0], dtype=np.int32))
+
+    centers = cl.array.to_device(
+        queue,
+        np.array(
             [
                 # box 0: source
                 psource.center,
 
                 # box 1: target
                 to_center,
-                ],
-            dtype=np.float64)).T.copy()
-
-    coeffs = np.array([psource.coeffs])
+            ],
+            dtype=np.float64).T.copy()
+        )
+    coeffs = cl.array.to_device(queue, np.array([psource.coeffs]))
 
     evt, (to_coeffs,) = e2e(
             toy_ctx.queue,
@@ -294,10 +312,10 @@ def _e2e(psource, to_center, to_rscale, to_order, e2e, expn_class, expn_kwargs):
             src_rscale=psource.rscale,
             tgt_rscale=to_rscale,
 
-            #flags="print_hl_cl",
             out_host=True, **toy_ctx.extra_kernel_kwargs)
 
-    return expn_class(toy_ctx, to_center, to_rscale, to_order, to_coeffs[1],
+    return expn_class(
+            toy_ctx, to_center, to_rscale, to_order, to_coeffs[1].get(queue),
             derived_from=psource, **expn_kwargs)
 
 # }}}
@@ -443,12 +461,16 @@ class PointSources(PotentialSource):
         self._center = center
 
     def eval(self, targets: np.ndarray) -> np.ndarray:
+        queue = self.toy_ctx.queue
         evt, (potential,) = self.toy_ctx.get_p2p()(
-                self.toy_ctx.queue, targets, self.points, [self.weights],
+                queue,
+                cl.array.to_device(queue, targets),
+                cl.array.to_device(queue, self.points),
+                [cl.array.to_device(queue, self.weights)],
                 out_host=True,
                 **self.toy_ctx.extra_source_and_kernel_kwargs)
 
-        return potential
+        return potential.get(queue)
 
     @property
     def center(self):
-- 
GitLab