From b492dd94b413d8a73ae2bef04b71e65b6aaa8d8c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 30 Mar 2018 20:57:57 -0500
Subject: [PATCH] Add memory access pattern visualizer

---
 contrib/mem-pattern-explorer/pattern_vis.py | 206 ++++++++++++++++++++
 1 file changed, 206 insertions(+)
 create mode 100644 contrib/mem-pattern-explorer/pattern_vis.py

diff --git a/contrib/mem-pattern-explorer/pattern_vis.py b/contrib/mem-pattern-explorer/pattern_vis.py
new file mode 100644
index 000000000..10eff2306
--- /dev/null
+++ b/contrib/mem-pattern-explorer/pattern_vis.py
@@ -0,0 +1,206 @@
+import numpy as np
+
+# Inspired by a visualization used in the Halide tutorial
+# https://www.youtube.com/watch?v=3uiEyEKji0M
+
+
+def div_ceil(nr, dr):
+    return -(-nr // dr)
+
+
+def product(iterable):
+    from operator import mul
+    from functools import reduce
+    return reduce(mul, iterable, 1)
+
+
+class ArrayAccessPatternContext:
+    def __init__(self, gsize, lsize, subgroup_size=32):
+        self.lsize = lsize
+        self.gsize = gsize
+        self.subgroup_size = subgroup_size
+        self.timestamp = 0
+
+        self.ind_length = len(gsize) + len(lsize)
+
+        self.arrays = []
+
+    def l(self, index):  # noqa: E743
+        subscript = [np.newaxis] * self.ind_length
+        subscript[len(self.gsize) + index] = slice(None)
+
+        return np.arange(self.lsize[index])[tuple(subscript)]
+
+    def g(self, index):
+        subscript = [np.newaxis] * self.ind_length
+        subscript[index] = slice(None)
+
+        return np.arange(self.gsize[index])[tuple(subscript)]
+
+    def nsubgroups(self):
+        return div_ceil(product(self.lsize), self.subgroup_size)
+
+    def animate(self, f):
+        import matplotlib.pyplot as plt
+        import matplotlib.animation as animation
+
+        fig = plt.figure()
+
+        plots = []
+        for iary, ary in enumerate(self.arrays):
+            ax = fig.add_subplot(1, len(self.arrays), 1+iary)
+            ax.set_title(ary.name)
+            plots.append(ary.plot(ax))
+
+        def data_gen():
+            for _ in f():
+                self.tick()
+
+                for ary, plot in zip(self.arrays, plots):
+                    plot.set_array(ary.get_plot_data())
+
+                fig.canvas.draw()
+                yield plots
+
+        # must be kept alive until after plt.show()
+        return animation.FuncAnimation(
+                fig, lambda x: x, data_gen, blit=False, interval=200, repeat=True)
+
+    def tick(self):
+        self.timestamp += 1
+
+
+class Array:
+    def __init__(self, ctx, name, shape, strides, elements_per_row=None):
+        # Each array element stores a tuple:
+        # (timestamp, subgroup, g0, g1, g2, ) of last acccess
+
+        assert len(shape) == len(strides)
+
+        self.nattributes = 2+len(ctx.gsize)
+
+        if elements_per_row is None:
+            if len(shape) > 1:
+                minstride = min(strides)
+                for sh_i, st_i in zip(shape, strides):
+                    if st_i == minstride:
+                        elements_per_row = sh_i
+                        break
+        else:
+            elements_per_row = 256
+
+        self.array = np.zeros((product(shape), self.nattributes,), dtype=np.int32)
+
+        self.ctx = ctx
+        self.name = name
+        self.shape = shape
+        self.strides = strides
+        self.elements_per_row = elements_per_row
+
+        ctx.arrays.append(self)
+
+    def __getitem__(self, index):
+        if not isinstance(index, tuple):
+            index = (index,)
+
+        assert len(index) == len(self.shape)
+
+        all_subscript = (np.newaxis,) * self.ctx.ind_length
+
+        def reshape_ind(ind):
+            if not isinstance(ind, np.ndarray):
+                return ind[all_subscript]
+
+            else:
+                assert len(ind.shape) == self.ctx.ind_length
+
+        lin_index = sum(
+                ind_i * stride_i
+                for ind_i, stride_i in zip(index, self.strides))
+
+        self.array[lin_index, 0] = self.ctx.timestamp
+        for i, glength in enumerate(self.ctx.gsize):
+            if lin_index.shape[i] > 1:
+                self.array[lin_index, 2+i] = self.ctx.g(i)
+
+        workitem_index = 0
+        for i in range(len(self.ctx.lsize))[::-1]:
+            workitem_index = (
+                    workitem_index * self.ctx.lsize[i]
+                    + self.ctx.l(i))
+        subgroup = workitem_index//self.ctx.subgroup_size
+        self.array[lin_index, 1] = subgroup
+
+    def __setitem__(self, index, value):
+        self.__getitem__(index)
+
+    def get_plot_data(self):
+        nelements = self.array.shape[0]
+        base_shape = (
+                div_ceil(nelements, self.elements_per_row),
+                self.elements_per_row,)
+        shaped_array = np.zeros(
+                base_shape + (self.nattributes,),
+                dtype=np.float32)
+        shaped_array.reshape(-1, self.nattributes)[:nelements] = self.array
+
+        modulation = np.exp(-(self.ctx.timestamp-shaped_array[:, :, 0]))
+
+        subgroup = shaped_array[:, :, 1]/(self.ctx.nsubgroups()-1)
+
+        rgb_array = np.zeros(base_shape + (3,))
+        if 1:
+            if len(self.ctx.gsize) >= 1:
+                # g.0 -> red
+                rgb_array[:, :, 0] = shaped_array[:, :, 2]/(self.ctx.gsize[0]-1)
+            if len(self.ctx.gsize) >= 2:
+                # g.1 -> blue
+                rgb_array[:, :, 2] = shaped_array[:, :, 3]/(self.ctx.gsize[1]-1)
+        if 1:
+            rgb_array[:, :, 1] = subgroup
+
+        return rgb_array*modulation[:, :, np.newaxis]
+
+    def plot(self, ax, **kwargs):
+        return ax.imshow(
+                self.get_plot_data(), interpolation="nearest",
+                **kwargs)
+
+
+def show_example():
+    n = 2**7
+    n16 = div_ceil(n, 16)
+    ctx = ArrayAccessPatternContext(gsize=(n16, n16), lsize=(16, 16))
+    in0 = Array(ctx, "in0", (n, n), (n, 1))
+
+    if 0:
+        # knl a
+        i_inner = ctx.l(1)
+        i_outer = ctx.g(1)
+        k_inner = ctx.l(0)
+
+        def f():
+            for k_outer in range(n16):
+                in0[i_inner + i_outer*16, k_inner + k_outer*16]
+                yield
+    else:
+        # knl b
+        j_inner = ctx.l(0)
+        j_outer = ctx.g(0)
+        k_inner = ctx.l(1)
+
+        def f():
+            for k_outer in range(n16):
+                in0[k_inner + k_outer*16, j_inner + j_outer*16]
+                yield
+
+    ani = ctx.animate(f)
+    import matplotlib.pyplot as plt
+    if 1:
+        plt.show()
+    else:
+        ani.save("access.mp4")
+
+
+if __name__ == "__main__":
+    show_example()
-- 
GitLab