From e09283357ef01eaaea77f88f97b4a09268d413b7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 17 Feb 2021 18:41:51 -0600
Subject: [PATCH] Transpose: enable check, allow running without matplotlib

---
 examples/transpose.py | 40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/examples/transpose.py b/examples/transpose.py
index 9b07e2b0..6b06a988 100644
--- a/examples/transpose.py
+++ b/examples/transpose.py
@@ -102,7 +102,7 @@ def transpose_using_cl(ctx, queue, cpu_src, cls):
 
     w, h = cpu_src.shape
     result = numpy.empty((h, w), dtype=cpu_src.dtype)
-    cl.enqueue_read_buffer(queue, a_t_buf, result).wait()
+    cl.enqueue_copy(queue, result, a_t_buf).wait()
 
     a_buf.release()
     a_t_buf.release()
@@ -144,7 +144,7 @@ def benchmark_transpose():
     for dev in ctx.devices:
         assert dev.local_mem_size > 0
 
-    queue = cl.CommandQueue(ctx, 
+    queue = cl.CommandQueue(ctx,
             properties=cl.command_queue_properties.PROFILING_ENABLE)
 
     sizes = [int(((2**i) // 32) * 32)
@@ -186,27 +186,27 @@ def benchmark_transpose():
             a_buf.release()
             a_t_buf.release()
 
-    from matplotlib.pyplot import clf, plot, title, xlabel, ylabel, \
-            savefig, legend, grid
-    for i in range(len(methods)):
-        clf()
-        for j in range(i+1):
-            method = methods[j]
-            name = method.__name__.replace("Transpose", "")
-            plot(sizes, numpy.array(mem_bandwidths[method])/1e9, "o-", label=name)
+    try:
+        from matplotlib.pyplot import clf, plot, title, xlabel, ylabel, \
+                savefig, legend, grid
+    except ModuleNotFoundError:
+        pass
+    else:
+        for i in range(len(methods)):
+            clf()
+            for j in range(i+1):
+                method = methods[j]
+                name = method.__name__.replace("Transpose", "")
+                plot(sizes, numpy.array(mem_bandwidths[method])/1e9, "o-", label=name)
 
-        xlabel("Matrix width/height $N$")
-        ylabel("Memory Bandwidth [GB/s]")
-        legend(loc="best")
-        grid()
+            xlabel("Matrix width/height $N$")
+            ylabel("Memory Bandwidth [GB/s]")
+            legend(loc="best")
+            grid()
 
-        savefig("transpose-benchmark-%d.pdf" % i)
+            savefig("transpose-benchmark-%d.pdf" % i)
 
 
-
-
-
-
-#check_transpose()
+check_transpose()
 benchmark_transpose()
 
-- 
GitLab