From cd7b54f8b553cf635c29fa86e8aa31c62d5831fb Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sat, 20 Aug 2022 21:22:05 +0300
Subject: [PATCH] port examples to arraycontext

---
 examples/curve-pot.py          | 82 ++++++++++++++++++++--------------
 examples/expansion-toys.py     | 35 ++++++++++-----
 examples/sym-exp-complexity.py | 19 +++++---
 3 files changed, 85 insertions(+), 51 deletions(-)

diff --git a/examples/curve-pot.py b/examples/curve-pot.py
index 810990ad..8d1a91a2 100644
--- a/examples/curve-pot.py
+++ b/examples/curve-pot.py
@@ -1,16 +1,22 @@
-import pyopencl as cl
 import numpy as np
 import numpy.linalg as la
 
+import pyopencl as cl
+
 try:
     import matplotlib.pyplot as plt
-except ModuleNotFoundError:
-    plt = None
+    USE_MATPLOTLIB = True
+except ImportError:
+    USE_MATPLOTLIB = False
 
 try:
     from mayavi import mlab
-except ModuleNotFoundError:
-    mlab = None
+    USE_MAYAVI = True
+except ImportError:
+    USE_MAYAVI = False
+
+import logging
+logging.basicConfig(level=logging.INFO)
 
 
 def process_kernel(knl, what_operator):
@@ -45,17 +51,16 @@ def draw_pot_figure(aspect_ratio,
         ovsmp_center_exp=0.66,
         force_center_side=None):
 
-    import logging
-    logging.basicConfig(level=logging.INFO)
-
     if novsmp is None:
         novsmp = 4*nsrc
 
     if what_operator_lpot is None:
         what_operator_lpot = what_operator
 
+    from sumpy.array_context import PyOpenCLArrayContext
     ctx = cl.create_some_context()
     queue = cl.CommandQueue(ctx)
+    actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
 
     # {{{ make plot targets
 
@@ -86,7 +91,7 @@ def draw_pot_figure(aspect_ratio,
         knl_kwargs = {}
 
     vol_source_knl, vol_target_knl = process_kernel(knl, what_operator)
-    p2p = P2P(ctx, source_kernels=(vol_source_knl,),
+    p2p = P2P(actx.context, source_kernels=(vol_source_knl,),
             target_kernels=(vol_target_knl,),
             exclude_self=False,
             value_dtypes=np.complex128)
@@ -94,8 +99,10 @@ def draw_pot_figure(aspect_ratio,
     lpot_source_knl, lpot_target_knl = process_kernel(knl, what_operator_lpot)
 
     from sumpy.qbx import LayerPotential
-    lpot = LayerPotential(ctx, expansion=expn_class(knl, order=order),
-            source_kernels=(lpot_source_knl,), target_kernels=(lpot_target_knl,),
+    lpot = LayerPotential(actx.context,
+            expansion=expn_class(knl, order=order),
+            source_kernels=(lpot_source_knl,),
+            target_kernels=(lpot_target_knl,),
             value_dtypes=np.complex128)
 
     # }}}
@@ -142,8 +149,9 @@ def draw_pot_figure(aspect_ratio,
             + center_side[:, np.newaxis]
             * center_dist*native_curve.normal)
 
-    #native_curve.plot()
-    #plt.show()
+    if 0:
+        native_curve.plot()
+        plt.show()
 
     volpot_kwargs = knl_kwargs.copy()
     lpot_kwargs = knl_kwargs.copy()
@@ -169,7 +177,9 @@ def draw_pot_figure(aspect_ratio,
 
         def apply_lpot(x):
             xovsmp = np.dot(fim, x)
-            evt, (y,) = lpot(queue, native_curve.pos, ovsmp_curve.pos,
+            evt, (y,) = lpot(actx.queue,
+                    native_curve.pos,
+                    ovsmp_curve.pos,
                     centers,
                     [xovsmp * ovsmp_curve.speed * ovsmp_weights],
                     expansion_radii=np.ones(centers.shape[1]),
@@ -191,10 +201,14 @@ def draw_pot_figure(aspect_ratio,
     mode_nr = 0
     density = np.cos(mode_nr*2*np.pi*native_t).astype(np.complex128)
     ovsmp_density = np.cos(mode_nr*2*np.pi*ovsmp_t).astype(np.complex128)
-    evt, (vol_pot,) = p2p(queue, fp.points, native_curve.pos,
+    evt, (vol_pot,) = p2p(actx.queue,
+            fp.points,
+            native_curve.pos,
             [native_curve.speed*native_weights*density], **volpot_kwargs)
 
-    evt, (curve_pot,) = lpot(queue, native_curve.pos, ovsmp_curve.pos,
+    evt, (curve_pot,) = lpot(actx.queue,
+            native_curve.pos,
+            ovsmp_curve.pos,
             centers,
             [ovsmp_density * ovsmp_curve.speed * ovsmp_weights],
             expansion_radii=np.ones(centers.shape[1]),
@@ -202,7 +216,7 @@ def draw_pot_figure(aspect_ratio,
 
     # }}}
 
-    if 0:
+    if USE_MATPLOTLIB:
         # {{{ plot on-surface potential in 2D
 
         plt.plot(curve_pot, label="pot")
@@ -216,7 +230,7 @@ def draw_pot_figure(aspect_ratio,
         ("potential", vol_pot.real)
         ])
 
-    if 0:
+    if USE_MATPLOTLIB:
         # {{{ 2D false-color plot
 
         plt.clf()
@@ -230,12 +244,8 @@ def draw_pot_figure(aspect_ratio,
         # close the curve
         plt.plot(src[-1::-len(src)+1, 0], src[-1::-len(src)+1, 1], "o-k")
 
-        #plt.gca().set_aspect("equal", "datalim")
         cb = plt.colorbar(shrink=0.9)
         cb.set_label(r"$\log_{10}(\mathdefault{Error})$")
-        #from matplotlib.ticker import NullFormatter
-        #plt.gca().xaxis.set_major_formatter(NullFormatter())
-        #plt.gca().yaxis.set_major_formatter(NullFormatter())
         fp.set_matplotlib_limits()
 
         # }}}
@@ -261,7 +271,7 @@ def draw_pot_figure(aspect_ratio,
             plotval_vol[outlier_flag] = sum(
                     nb[outlier_flag] for nb in neighbors)/len(neighbors)
 
-        if mlab is not None:
+        if USE_MAYAVI:
             fp.show_scalar_in_mayavi(scale*plotval_vol, max_val=1)
             mlab.colorbar()
             if 1:
@@ -275,17 +285,23 @@ def draw_pot_figure(aspect_ratio,
 
 
 if __name__ == "__main__":
-    draw_pot_figure(aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=(35+4j)*0.3,
+    draw_pot_figure(
+            aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=(35+4j)*0.3,
             what_operator="D", what_operator_lpot="D", force_center_side=1)
+    if USE_MATPLOTLIB:
+        plt.savefig("eigvals-ext-nsrc100-novsmp100.pdf")
+        plt.clf()
 
-#    plt.savefig("eigvals-ext-nsrc100-novsmp100.pdf")
-    #plt.clf()
-    #draw_pot_figure(aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=0,
-    #        what_operator="D", what_operator_lpot="D", force_center_side=-1)
-    #plt.savefig("eigvals-int-nsrc100-novsmp100.pdf")
-    #plt.clf()
-    #draw_pot_figure(aspect_ratio=1, nsrc=100, novsmp=200, helmholtz_k=0,
-    #        what_operator="D", what_operator_lpot="D", force_center_side=-1)
-    #plt.savefig("eigvals-int-nsrc100-novsmp200.pdf")
+    # draw_pot_figure(
+    #         aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=0,
+    #         what_operator="D", what_operator_lpot="D", force_center_side=-1)
+    # plt.savefig("eigvals-int-nsrc100-novsmp100.pdf")
+    # plt.clf()
+
+    # draw_pot_figure(
+    #         aspect_ratio=1, nsrc=100, novsmp=200, helmholtz_k=0,
+    #         what_operator="D", what_operator_lpot="D", force_center_side=-1)
+    # plt.savefig("eigvals-int-nsrc100-novsmp200.pdf")
+    # plt.clf()
 
 # vim: fdm=marker
diff --git a/examples/expansion-toys.py b/examples/expansion-toys.py
index e774b17a..48647d0e 100644
--- a/examples/expansion-toys.py
+++ b/examples/expansion-toys.py
@@ -1,21 +1,32 @@
+import numpy as np
+
 import pyopencl as cl
+
 import sumpy.toys as t
-import numpy as np
 from sumpy.visualization import FieldPlotter
+from sumpy.kernel import (      # noqa: F401
+        YukawaKernel,
+        HelmholtzKernel,
+        LaplaceKernel)
+
 try:
     import matplotlib.pyplot as plt
-except ModuleNotFoundError:
-    plt = None
+    USE_MATPLOTLIB = True
+except ImportError:
+    USE_MATPLOTLIB = False
 
 
 def main():
-    from sumpy.kernel import (  # noqa: F401
-            YukawaKernel, HelmholtzKernel, LaplaceKernel)
+    from sumpy.array_context import PyOpenCLArrayContext
+    ctx = cl.create_some_context()
+    queue = cl.CommandQueue(ctx)
+    actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
+
     tctx = t.ToyContext(
-            cl.create_some_context(),
-            #LaplaceKernel(2),
+            actx.context,
+            # LaplaceKernel(2),
             YukawaKernel(2), extra_kernel_kwargs={"lam": 5},
-            #HelmholtzKernel(2), extra_kernel_kwargs={"k": 0.3},
+            # HelmholtzKernel(2), extra_kernel_kwargs={"k": 0.3},
             )
 
     pt_src = t.PointSources(
@@ -25,7 +36,7 @@ def main():
 
     fp = FieldPlotter([3, 0], extent=8)
 
-    if 0 and plt is not None:
+    if USE_MATPLOTLIB:
         t.logplot(fp, pt_src, cmap="jet")
         plt.colorbar()
         plt.show()
@@ -35,12 +46,12 @@ def main():
     lexp = t.local_expand(mexp, [3, 0])
     lexp2 = t.local_expand(lexp, [3, 1], 3)
 
-    #diff = mexp - pt_src
-    #diff = mexp2 - pt_src
+    # diff = mexp - pt_src
+    # diff = mexp2 - pt_src
     diff = lexp2 - pt_src
 
     print(t.l_inf(diff, 1.2, center=lexp2.center))
-    if 1 and plt is not None:
+    if USE_MATPLOTLIB:
         t.logplot(fp, diff, cmap="jet", vmin=-3, vmax=0)
         plt.colorbar()
         plt.show()
diff --git a/examples/sym-exp-complexity.py b/examples/sym-exp-complexity.py
index ae91a932..779bfc36 100644
--- a/examples/sym-exp-complexity.py
+++ b/examples/sym-exp-complexity.py
@@ -1,6 +1,8 @@
 import numpy as np
-import pyopencl as cl
 import loopy as lp
+
+import pyopencl as cl
+
 from sumpy.kernel import LaplaceKernel, HelmholtzKernel
 from sumpy.expansion.local import (
         LinearPDEConformingVolumeTaylorLocalExpansion,
@@ -9,14 +11,19 @@ from sumpy.expansion.multipole import (
         LinearPDEConformingVolumeTaylorMultipoleExpansion,
         )
 from sumpy.e2e import E2EFromCSR
+
 try:
     import matplotlib.pyplot as plt
-except ModuleNotFoundError:
-    plt = None
+    USE_MATPLOTLIB = True
+except ImportError:
+    USE_MATPLOTLIB = False
 
 
 def find_flops():
+    from sumpy.array_context import PyOpenCLArrayContext
     ctx = cl.create_some_context()
+    queue = cl.CommandQueue(ctx)
+    actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
 
     if 0:
         knl = LaplaceKernel(2)
@@ -35,7 +42,7 @@ def find_flops():
         print(order)
         m_expn = m_expn_cls(knl, order)
         l_expn = l_expn_cls(knl, order)
-        m2l = E2EFromCSR(ctx, m_expn, l_expn)
+        m2l = E2EFromCSR(actx.context, m_expn, l_expn)
 
         loopy_knl = m2l.get_kernel()
         loopy_knl = lp.add_and_infer_dtypes(
@@ -74,7 +81,7 @@ def plot_flops():
         flops = [45, 194, 474, 931, 1650, 2632, 3925, 5591, 7706, 10272]
         filename = "helmholtz-m2l-complexity-2d.pdf"
 
-    if plt is not None:
+    if USE_MATPLOTLIB:
         plt.rc("font", size=16)
         plt.title(case)
         plt.ylabel("Flop count")
@@ -86,5 +93,5 @@ def plot_flops():
 
 
 if __name__ == "__main__":
-    #find_flops()
+    # find_flops()
     plot_flops()
-- 
GitLab