Skip to content
Snippets Groups Projects
Commit c234875a authored by Christoph Gohlke's avatar Christoph Gohlke
Browse files

Clean up demo_mandelbrot.py

Restore Python 3 compatibility; Apply PEP8 and PyFlakes; Remove unused imports.
parent bdf66f86
Branches
Tags
No related merge requests found
...@@ -14,11 +14,9 @@ ...@@ -14,11 +14,9 @@
# http://www.daniweb.com/code/snippet216851.html# # http://www.daniweb.com/code/snippet216851.html#
# with minor changes to move to numpy from the obsolete Numeric # with minor changes to move to numpy from the obsolete Numeric
import numpy as np
import time import time
import numpy import numpy as np
import numpy.linalg as la
import pyopencl as cl import pyopencl as cl
...@@ -30,6 +28,7 @@ import pyopencl as cl ...@@ -30,6 +28,7 @@ import pyopencl as cl
w = 512 w = 512
h = 512 h = 512
def calc_fractal_opencl(q, maxiter): def calc_fractal_opencl(q, maxiter):
ctx = cl.create_some_context() ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx) queue = cl.CommandQueue(ctx)
...@@ -63,14 +62,13 @@ def calc_fractal_opencl(q, maxiter): ...@@ -63,14 +62,13 @@ def calc_fractal_opencl(q, maxiter):
""").build() """).build()
prg.mandelbrot(queue, output.shape, None, q_opencl, prg.mandelbrot(queue, output.shape, None, q_opencl,
output_opencl, np.uint16(maxiter)) output_opencl, np.uint16(maxiter))
cl.enqueue_copy(queue, output, output_opencl).wait() cl.enqueue_copy(queue, output, output_opencl).wait()
return output return output
def calc_fractal_serial(q, maxiter): def calc_fractal_serial(q, maxiter):
# calculate z using numpy # calculate z using numpy
# this routine unrolls calc_fractal_numpy as an intermediate # this routine unrolls calc_fractal_numpy as an intermediate
...@@ -79,26 +77,27 @@ def calc_fractal_serial(q, maxiter): ...@@ -79,26 +77,27 @@ def calc_fractal_serial(q, maxiter):
z = np.zeros(q.shape, np.complex64) z = np.zeros(q.shape, np.complex64)
output = np.resize(np.array(0,), q.shape) output = np.resize(np.array(0,), q.shape)
for i in range(len(q)): for i in range(len(q)):
for iter in range(maxiter): for it in range(maxiter):
z[i] = z[i]*z[i] + q[i] z[i] = z[i]*z[i] + q[i]
if abs(z[i]) > 2.0: if abs(z[i]) > 2.0:
q[i] = 0+0j q[i] = 0+0j
z[i] = 0+0j z[i] = 0+0j
output[i] = iter output[i] = it
return output return output
def calc_fractal_numpy(q, maxiter): def calc_fractal_numpy(q, maxiter):
# calculate z using numpy, this is the original # calculate z using numpy, this is the original
# routine from vegaseat's URL # routine from vegaseat's URL
output = np.resize(np.array(0,), q.shape) output = np.resize(np.array(0,), q.shape)
z = np.zeros(q.shape, np.complex64) z = np.zeros(q.shape, np.complex64)
for iter in range(maxiter): for it in range(maxiter):
z = z*z + q z = z*z + q
done = np.greater(abs(z), 2.0) done = np.greater(abs(z), 2.0)
q = np.where(done,0+0j, q) q = np.where(done, 0+0j, q)
z = np.where(done,0+0j, z) z = np.where(done, 0+0j, z)
output = np.where(done, iter, output) output = np.where(done, it, output)
return output return output
# choose your calculation routine here by uncommenting one of the options # choose your calculation routine here by uncommenting one of the options
...@@ -107,10 +106,13 @@ calc_fractal = calc_fractal_opencl ...@@ -107,10 +106,13 @@ calc_fractal = calc_fractal_opencl
# calc_fractal = calc_fractal_numpy # calc_fractal = calc_fractal_numpy
if __name__ == '__main__': if __name__ == '__main__':
import Tkinter as tk try:
import Tkinter as tk
except ImportError:
# Python 3
import tkinter as tk
from PIL import Image, ImageTk from PIL import Image, ImageTk
class Mandelbrot(object): class Mandelbrot(object):
def __init__(self): def __init__(self):
# create window # create window
...@@ -121,7 +123,6 @@ if __name__ == '__main__': ...@@ -121,7 +123,6 @@ if __name__ == '__main__':
# start event loop # start event loop
self.root.mainloop() self.root.mainloop()
def draw(self, x1, x2, y1, y2, maxiter=30): def draw(self, x1, x2, y1, y2, maxiter=30):
# draw the Mandelbrot set, from numpy example # draw the Mandelbrot set, from numpy example
xx = np.arange(x1, x2, (x2-x1)/w) xx = np.arange(x1, x2, (x2-x1)/w)
...@@ -135,8 +136,8 @@ if __name__ == '__main__': ...@@ -135,8 +136,8 @@ if __name__ == '__main__':
secs = end_main - start_main secs = end_main - start_main
print("Main took", secs) print("Main took", secs)
self.mandel = (output.reshape((h,w)) / self.mandel = (output.reshape((h, w)) /
float(output.max()) * 255.).astype(np.uint8) float(output.max()) * 255.).astype(np.uint8)
def create_image(self): def create_image(self):
"""" """"
...@@ -145,10 +146,8 @@ if __name__ == '__main__': ...@@ -145,10 +146,8 @@ if __name__ == '__main__':
# you can experiment with these x and y ranges # you can experiment with these x and y ranges
self.draw(-2.13, 0.77, -1.3, 1.3) self.draw(-2.13, 0.77, -1.3, 1.3)
self.im = Image.fromarray(self.mandel) self.im = Image.fromarray(self.mandel)
self.im.putpalette(reduce( self.im.putpalette([i for rgb in ((j, 0, 0) for j in range(255))
lambda a,b: a+b, ((i,0,0) for i in range(255)) for i in rgb])
))
def create_label(self): def create_label(self):
# put the image on a label widget # put the image on a label widget
...@@ -158,4 +157,3 @@ if __name__ == '__main__': ...@@ -158,4 +157,3 @@ if __name__ == '__main__':
# test the class # test the class
test = Mandelbrot() test = Mandelbrot()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment