From e8076d1ebc2ff01589d765b2485274c1b9206e78 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 11 Jan 2021 17:45:08 -0600 Subject: [PATCH] Make plot imports in examples optional --- examples/curve-pot.py | 29 +++++++++++++++++++---------- examples/expansion-toys.py | 9 ++++++--- examples/sym-exp-complexity.py | 26 +++++++++++++++----------- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/examples/curve-pot.py b/examples/curve-pot.py index 7ce25a82..7d34ee34 100644 --- a/examples/curve-pot.py +++ b/examples/curve-pot.py @@ -1,7 +1,16 @@ import pyopencl as cl import numpy as np import numpy.linalg as la -import matplotlib.pyplot as pt + +try: + import matplotlib.pyplot as plt +except ModuleNotFoundError: + plt = None + +try: + from mayavi import mlab +except ModuleNotFoundError: + mlab = None def process_kernel(knl, what_operator): @@ -248,15 +257,15 @@ def draw_pot_figure(aspect_ratio, plotval_vol[outlier_flag] = sum( nb[outlier_flag] for nb in neighbors)/len(neighbors) - fp.show_scalar_in_mayavi(scale*plotval_vol, max_val=1) - from mayavi import mlab - mlab.colorbar() - if 1: - mlab.points3d( - native_curve.pos[0], - native_curve.pos[1], - scale*plotval_c, scale_factor=0.02) - mlab.show() + if mlab is not None: + fp.show_scalar_in_mayavi(scale*plotval_vol, max_val=1) + mlab.colorbar() + if 1: + mlab.points3d( + native_curve.pos[0], + native_curve.pos[1], + scale*plotval_c, scale_factor=0.02) + mlab.show() # }}} diff --git a/examples/expansion-toys.py b/examples/expansion-toys.py index 12543b93..97c0eceb 100644 --- a/examples/expansion-toys.py +++ b/examples/expansion-toys.py @@ -2,7 +2,10 @@ import pyopencl as cl import sumpy.toys as t import numpy as np from sumpy.visualization import FieldPlotter -import matplotlib.pyplot as plt +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None def main(): @@ -22,7 +25,7 @@ def main(): fp = FieldPlotter([3, 0], extent=8) - if 0: + if 0 and plt is not None: t.logplot(fp, pt_src, cmap="jet") plt.colorbar() plt.show() @@ -37,7 +40,7 @@ def main(): diff = lexp2 - pt_src print(t.l_inf(diff, 1.2, center=lexp2.center)) - if 1: + if 1 and plt is not None: t.logplot(fp, diff, cmap="jet", vmin=-3, vmax=0) plt.colorbar() plt.show() diff --git a/examples/sym-exp-complexity.py b/examples/sym-exp-complexity.py index bde21c42..49e20ead 100644 --- a/examples/sym-exp-complexity.py +++ b/examples/sym-exp-complexity.py @@ -11,6 +11,10 @@ from sumpy.expansion.multipole import ( HelmholtzConformingVolumeTaylorMultipoleExpansion, ) from sumpy.e2e import E2EFromCSR +try: + import matplotlib.pyplot as plt +except ModuleNotFoundError: + plt = None def find_flops(): @@ -59,28 +63,28 @@ def plot_flops(): flops = [62, 300, 914, 2221, 4567, 8405, 14172, 22538, 34113] filename = "laplace-m2l-complexity-3d.pdf" - elif 0: + elif 1: case = "2D Laplace M2L" orders = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] flops = [36, 99, 193, 319, 476, 665, 889, 1143, 1429, 1747, 2097, 2479, 2893, 3339, 3817, 4327, 4869, 5443, 6049, 6687] filename = "laplace-m2l-complexity-2d.pdf" - elif 1: + elif 0: case = "2D Helmholtz M2L" orders = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] flops = [45, 194, 474, 931, 1650, 2632, 3925, 5591, 7706, 10272] filename = "helmholtz-m2l-complexity-2d.pdf" - import matplotlib.pyplot as plt - plt.rc("font", size=16) - plt.title(case) - plt.ylabel("Flop count") - plt.xlabel("Expansion order") - plt.loglog(orders, flops, "o-") - plt.grid() - plt.tight_layout() - plt.savefig(filename) + if plt is not None: + plt.rc("font", size=16) + plt.title(case) + plt.ylabel("Flop count") + plt.xlabel("Expansion order") + plt.loglog(orders, flops, "o-") + plt.grid() + plt.tight_layout() + plt.savefig(filename) if __name__ == "__main__": -- GitLab