diff --git a/examples/demo_mandelbrot.py b/examples/demo_mandelbrot.py index eb53f41640072466af40d4c17d5a0a9b932faf33..c6a4f50d5a52dd77bee9b35b09a2ce1d4d78be57 100644 --- a/examples/demo_mandelbrot.py +++ b/examples/demo_mandelbrot.py @@ -14,11 +14,9 @@ # http://www.daniweb.com/code/snippet216851.html# # with minor changes to move to numpy from the obsolete Numeric -import numpy as np import time -import numpy -import numpy.linalg as la +import numpy as np import pyopencl as cl @@ -30,6 +28,7 @@ import pyopencl as cl w = 512 h = 512 + def calc_fractal_opencl(q, maxiter): ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -63,14 +62,13 @@ def calc_fractal_opencl(q, maxiter): """).build() prg.mandelbrot(queue, output.shape, None, q_opencl, - output_opencl, np.uint16(maxiter)) + output_opencl, np.uint16(maxiter)) cl.enqueue_copy(queue, output, output_opencl).wait() return output - def calc_fractal_serial(q, maxiter): # calculate z using numpy # this routine unrolls calc_fractal_numpy as an intermediate @@ -79,26 +77,27 @@ def calc_fractal_serial(q, maxiter): z = np.zeros(q.shape, np.complex64) output = np.resize(np.array(0,), q.shape) for i in range(len(q)): - for iter in range(maxiter): + for it in range(maxiter): z[i] = z[i]*z[i] + q[i] if abs(z[i]) > 2.0: q[i] = 0+0j z[i] = 0+0j - output[i] = iter + output[i] = it return output + def calc_fractal_numpy(q, maxiter): # calculate z using numpy, this is the original # routine from vegaseat's URL output = np.resize(np.array(0,), q.shape) z = np.zeros(q.shape, np.complex64) - for iter in range(maxiter): + for it in range(maxiter): z = z*z + q done = np.greater(abs(z), 2.0) - q = np.where(done,0+0j, q) - z = np.where(done,0+0j, z) - output = np.where(done, iter, output) + q = np.where(done, 0+0j, q) + z = np.where(done, 0+0j, z) + output = np.where(done, it, output) return output # choose your calculation routine here by uncommenting one of the options @@ -107,10 +106,13 @@ calc_fractal = calc_fractal_opencl # calc_fractal = calc_fractal_numpy if __name__ == '__main__': - import Tkinter as tk + try: + import Tkinter as tk + except ImportError: + # Python 3 + import tkinter as tk from PIL import Image, ImageTk - class Mandelbrot(object): def __init__(self): # create window @@ -121,7 +123,6 @@ if __name__ == '__main__': # start event loop self.root.mainloop() - def draw(self, x1, x2, y1, y2, maxiter=30): # draw the Mandelbrot set, from numpy example xx = np.arange(x1, x2, (x2-x1)/w) @@ -135,8 +136,8 @@ if __name__ == '__main__': secs = end_main - start_main print("Main took", secs) - self.mandel = (output.reshape((h,w)) / - float(output.max()) * 255.).astype(np.uint8) + self.mandel = (output.reshape((h, w)) / + float(output.max()) * 255.).astype(np.uint8) def create_image(self): """" @@ -145,10 +146,8 @@ if __name__ == '__main__': # you can experiment with these x and y ranges self.draw(-2.13, 0.77, -1.3, 1.3) self.im = Image.fromarray(self.mandel) - self.im.putpalette(reduce( - lambda a,b: a+b, ((i,0,0) for i in range(255)) - )) - + self.im.putpalette([i for rgb in ((j, 0, 0) for j in range(255)) + for i in rgb]) def create_label(self): # put the image on a label widget @@ -158,4 +157,3 @@ if __name__ == '__main__': # test the class test = Mandelbrot() -