The Samovar

Fast fractals with Python and numpy

This will be of little interest to people who regularly read my blog, but might be of some interest to people who find their way here by the power of Google.

The standard way to compute fractals like the Mandelbrot set using Python and numpy is to use vectorisation and do the operations on a whole set of points. The problem is that this is slower than it needs to be because you keep doing computations on points that have already escaped. This can be avoided though, and the version below is about 3x faster than the standard way of doing it with numpy.
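For comparison, the standard whole-grid approach might look roughly like the sketch below. This is not code from the original post: the name `mandel_naive` and the step that parks escaped points at 2.0 (so they cannot overflow) are assumptions made for illustration, and it is written for Python 3.

```python
import numpy as np

def mandel_naive(n, m, itermax, xmin, xmax, ymin, ymax):
    """Whole-grid Mandelbrot: every point is updated on every iteration."""
    x = np.linspace(xmin, xmax, n)[:, None]   # column of real parts
    y = np.linspace(ymin, ymax, m)[None, :]   # row of imaginary parts
    c = x + 1j * y                            # broadcasts to shape (n, m)
    z = c.copy()                              # first iterate of 0 -> 0^2 + c
    img = np.zeros(c.shape, dtype=int)        # 0 means "never escaped"
    for i in range(itermax):
        z = z * z + c                         # touches all n*m points, escaped or not
        escaped = np.abs(z) > 2.0
        img[escaped & (img == 0)] = i + 1     # colour only on first escape
        z[escaped] = 2.0                      # park escaped points to avoid overflow
    return img
```

The wasted work is in the line `z = z * z + c`, which keeps updating points whose colour has already been decided.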

The trick is to create a new array at each iteration that stores only the points which haven’t yet escaped. The slight complication is that if you do this you need to keep track of the x, y coordinates of each of the points as well as the values of the iterate z. The same trick can be applied to many types of fractals and makes Python and numpy almost as good as C++ for mathematical exploration of fractals.
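The bookkeeping can be seen in isolation on a tiny 1-D example. This is only a sketch: `drop_escaped` is a hypothetical helper, not part of the code in this post, and it assumes Python 3 with a recent numpy.

```python
import numpy as np

def drop_escaped(z, idx, threshold=2.0):
    """Filter z and its index array together.

    Returns (escaped_idx, z_kept, idx_kept): the original positions of the
    points that escaped, plus the surviving values and their positions.
    """
    escaped = np.abs(z) > threshold
    keep = ~escaped                 # boolean NOT
    # The same mask is applied to both arrays, so entry k of z_kept still
    # belongs to original position idx_kept[k].
    return idx[escaped], z[keep], idx[keep]

z = np.array([0.5, 3.0, 1.0, 2.5, 0.1])
idx = np.arange(len(z))
gone, z, idx = drop_escaped(z, idx)
```

After the call, `gone` holds the positions to colour in the output image, and `z` and `idx` are shorter arrays that still line up with each other.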

I’ve included the code below, both with and without explanatory comments. The 400×400 image produced by the script at the end, using 100 iterations, took 1.1s to compute on my 1.8GHz laptop.


Uncommented version:

from numpy import *

def mandel(n, m, itermax, xmin, xmax, ymin, ymax):
    ix, iy = mgrid[0:n, 0:m]
    x = linspace(xmin, xmax, n)[ix]
    y = linspace(ymin, ymax, m)[iy]
    c = x+complex(0,1)*y
    del x, y
    img = zeros(c.shape, dtype=int)
    ix.shape = n*m
    iy.shape = n*m
    c.shape = n*m
    z = copy(c)
    for i in range(itermax):
        if not len(z): break
        multiply(z, z, z)
        add(z, c, z)
        rem = abs(z)>2.0
        img[ix[rem], iy[rem]] = i+1
        rem = ~rem
        z = z[rem]
        ix, iy = ix[rem], iy[rem]
        c = c[rem]
    return img

Commented version:

from numpy import *

def mandel(n, m, itermax, xmin, xmax, ymin, ymax):
    """
    Fast mandelbrot computation using numpy.

    (n, m) are the output image dimensions
    itermax is the maximum number of iterations to do
    xmin, xmax, ymin, ymax specify the region of the
    set to compute.
    """
    # The point of ix and iy is that they are 2D arrays
    # giving the x-coord and y-coord at each point in
    # the array. The reason for doing this will become
    # clear below...
    ix, iy = mgrid[0:n, 0:m]
    # Now x and y are the x-values and y-values at each
    # point in the array, linspace(start, end, n)
    # is an array of n linearly spaced points between
    # start and end, and we then index this array using
    # numpy fancy indexing. If A is an array and I is
    # an array of indices, then A[I] has the same shape
    # as I and at each place i in I has the value A[i].
    x = linspace(xmin, xmax, n)[ix]
    y = linspace(ymin, ymax, m)[iy]
    # c is the complex number with the given x, y coords
    c = x+complex(0,1)*y
    del x, y # save a bit of memory, we only need c from here on
    # the output image coloured according to the number
    # of iterations it takes to get to the boundary
    # abs(z)>2
    img = zeros(c.shape, dtype=int)
    # Here is where the improvement over the standard
    # algorithm for drawing fractals in numpy comes in.
    # We flatten all the arrays ix, iy and c. This
    # flattening doesn't use any more memory because
    # we are just changing the shape of the array, the
    # data in memory stays the same. It also affects
    # each array in the same way, so that index i in
    # array c has x, y coords ix[i], iy[i]. The way the
    # algorithm works is that whenever abs(z)>2 we
    # remove the corresponding index from each of the
    # arrays ix, iy and c. Since we do the same thing
    # to each array, the correspondence between c and
    # the x, y coords stored in ix and iy is kept.
    ix.shape = n*m
    iy.shape = n*m
    c.shape = n*m
    # we iterate z->z^2+c with z starting at 0, but the
    # first iteration makes z=c so we just start there.
    # We need to copy c because otherwise the operation
    # z->z^2 will send c->c^2.
    z = copy(c)
    for i in range(itermax):
        if not len(z): break # all points have escaped
        # equivalent to z = z*z+c but quicker and uses
        # less memory
        multiply(z, z, z)
        add(z, c, z)
        # these are the points that have escaped
        rem = abs(z)>2.0
        # colour them with the iteration number, we
        # add one so that points which haven't
        # escaped have 0 as their iteration number,
        # this is why we keep the arrays ix and iy
        # because we need to know which point in img
        # to colour
        img[ix[rem], iy[rem]] = i+1
        # ~rem is the array of points which haven't
        # escaped, in numpy ~A for a boolean array A
        # is the NOT operation.
        rem = ~rem
        # So we select out the points in
        # z, ix, iy and c which are still to be
        # iterated on in the next step
        z = z[rem]
        ix, iy = ix[rem], iy[rem]
        c = c[rem]
    return img

if __name__=='__main__':
    from pylab import *
    import time
    start = time.time()
    I = mandel(400, 400, 100, -2, .5, -1.25, 1.25)
    print('Time taken:', time.time()-start)
    I[I==0] = 101
    img = imshow(I.T, origin='lower')
    imsave('mandel.png', I.T, origin='lower')

Countdown numbers game in Python

One of the things about being ill is that you have to spend a lot of time in bed with nothing much to do. Having watched the whole first series of the Sopranos, I had to find something else. So here’s the result. I revisited an old program I wrote many years ago to solve the Countdown numbers game.

In this game, you’re given six numbers between 1 and 100 and a target number between 100 and 999. You’re given 30 seconds to try to make the target using the six numbers and the operations plus, minus, times and divide.
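As a rough illustration of the rules, here is a small checker for a proposed solution. It is a sketch under assumptions that the description above does not spell out: every intermediate result must be a positive whole number and division must be exact. The names `apply_step` and `check_solution` are hypothetical, and the checker does not verify that each number used actually comes from the six sources or an earlier step.

```python
def apply_step(op, a, b):
    """Return the result of one step, or None if the step is not allowed."""
    if op == '+':
        return a + b
    if op == '*':
        return a * b
    if op == '-':
        return a - b if a > b else None              # result must stay positive
    if op == '/':
        return a // b if b and a % b == 0 else None  # division must be exact
    raise ValueError('unknown operator: %r' % (op,))

def check_solution(steps, target):
    """steps is a list of (op, a, b) tuples; the last result must equal target."""
    result = None
    for op, a, b in steps:
        result = apply_step(op, a, b)
        if result is None:
            return False
    return result == target

# Sources 100, 6, 3, 75, 50, 25 and target 952:
steps = [('+', 100, 6), ('*', 106, 3), ('*', 318, 75),
         ('-', 23850, 50), ('/', 23800, 25)]
print(check_solution(steps, 952))
```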

I originally wrote a program to solve this many years ago (when I was about 14 I think), but the algorithm I used was pretty horrible. I worked out by hand all the possible arrangements of brackets you could have for six numbers, and then tried each operator and number in each appropriate slot. It worked, but it was ugly programming.

Recently I’ve been learning Python for an academic project, and so I thought I may as well try rewriting it in Python. I think the solution I’ve come up with is nicer than any of the solutions I’ve found on the internet (mostly written in Java or C), although having written it I found this paper which uses a very similar solution to mine (but in Haskell rather than Python).

Python programmers might get something from the minimal code below (all comments and docs stripped out), or you can take a look at the full source code here, including detailed comments and docs explaining the code and algorithm.

My ideal (as always with Python) was to write a program you could just look at and understand the source code without comments, but I don’t think I achieved that. I’d be interested if a more experienced Python programmer could do so. Let me know.

The following excerpt is from the slower version and is incomplete. It is meant to be understandable without explanation (it takes about 40 seconds to find all solutions, which is too slow for Countdown):

def ValidExpressions(sources,operators=standard_operators,minimal_remaining_sources=0):
    for value, i in zip(sources,range(len(sources))):
        yield TerminalExpression(value=value, remaining_sources=sources[:i]+sources[i+1:])
    if len(sources)>=2+minimal_remaining_sources:
        for lhs in ValidExpressions(sources,operators,minimal_remaining_sources+1):
            for rhs in ValidExpressions(lhs.remaining_sources, operators, minimal_remaining_sources):
                for f in operators:
                    try: yield BranchedExpression(operator=f, lhs=lhs, rhs=rhs, remaining_sources=rhs.remaining_sources)
                    except InvalidExpressionError: pass

def TargetExpressions(target,sources,operators=standard_operators):
    for expression in ValidExpressions(sources,operators):
        if expression.value==target:
            yield expression
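The excerpt relies on names it does not define: `TerminalExpression`, `BranchedExpression`, `InvalidExpressionError` and `standard_operators`. The sketch below shows one way they might look so that the generators can be run. Everything outside the two generator functions is an assumption for illustration (including the operator functions `plus`, `minus`, `times`, `divide` and the rule that results must stay positive integers), not the full source referred to above.

```python
class InvalidExpressionError(ValueError):
    """Raised when an operator is applied to values the game does not allow."""

def plus(x, y):
    return x + y

def minus(x, y):
    if x <= y:
        raise InvalidExpressionError('result must be positive')
    return x - y

def times(x, y):
    return x * y

def divide(x, y):
    if y == 0 or x % y:
        raise InvalidExpressionError('division must be exact')
    return x // y

standard_operators = [plus, minus, times, divide]

class TerminalExpression(object):
    """A single source number."""
    def __init__(self, value, remaining_sources):
        self.value = value
        self.remaining_sources = remaining_sources

class BranchedExpression(object):
    """An operator applied to two sub-expressions."""
    def __init__(self, operator, lhs, rhs, remaining_sources):
        # Evaluating here lets an illegal step raise InvalidExpressionError
        # before the expression is ever yielded.
        self.value = operator(lhs.value, rhs.value)
        self.operator, self.lhs, self.rhs = operator, lhs, rhs
        self.remaining_sources = remaining_sources

# The two generators from the excerpt, unchanged apart from line wrapping.
def ValidExpressions(sources, operators=standard_operators, minimal_remaining_sources=0):
    for value, i in zip(sources, range(len(sources))):
        yield TerminalExpression(value=value, remaining_sources=sources[:i]+sources[i+1:])
    if len(sources) >= 2+minimal_remaining_sources:
        for lhs in ValidExpressions(sources, operators, minimal_remaining_sources+1):
            for rhs in ValidExpressions(lhs.remaining_sources, operators, minimal_remaining_sources):
                for f in operators:
                    try:
                        yield BranchedExpression(operator=f, lhs=lhs, rhs=rhs,
                                                 remaining_sources=rhs.remaining_sources)
                    except InvalidExpressionError:
                        pass

def TargetExpressions(target, sources, operators=standard_operators):
    for expression in ValidExpressions(sources, operators):
        if expression.value == target:
            yield expression
```

With these definitions, `list(TargetExpressions(14, [2, 3, 4]))` yields every expression over those three sources whose value is 14. Note that `sources` has to be a list (or tuple) because the generator slices and concatenates it.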

The following is the complete faster version, which needs the comments to explain it (it takes about 15 seconds to run, fast enough to win Countdown):

sub = lambda x,y: x-y
def add(x,y):
    if x<=y: return x+y
    raise ValueError
def mul(x,y):
    if x<...
    ...
    if len(sources)>=2+minremsources:
        for e1, rs1, v1 in expressions(sources,ops,minremsources+1):
            for e2, rs2, v2 in expressions(rs1,ops,minremsources):
                for o in ops:
                    try: yield ([o,e1,e2],rs2,o(v1,v2))
                    except ValueError: pass

def findfirsttarget(target,sources,ops=standard_ops):
    for e,s,v in expressions(sources,ops):
        if v==target:
            return e
    return []
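The listing above has lost the text between the start of `mul` and the body of `expressions`, so it cannot be run as printed. The sketch below fills the hole with guesses so that the algorithm can be tried out. The parts marked ASSUMPTION are not taken from the original: the body of `mul`, the whole of `div`, the contents of `standard_ops`, the header and terminal case of `expressions`, and the helper `evaluate`. It is written for Python 3.

```python
sub = lambda x, y: x-y

def add(x, y):
    if x <= y: return x+y
    raise ValueError

def mul(x, y):
    # ASSUMPTION: same ordering guard as add, to skip mirrored duplicates
    if x <= y: return x*y
    raise ValueError

def div(x, y):
    # ASSUMPTION: only exact division by a non-zero number is allowed
    if y == 0 or x % y: raise ValueError
    return x // y

standard_ops = [add, sub, mul, div]   # ASSUMPTION

def expressions(sources, ops=standard_ops, minremsources=0):
    # ASSUMPTION: a terminal expression is just the number itself; each item
    # yielded is (expression, remaining sources, value)
    for i in range(len(sources)):
        yield (sources[i], sources[:i]+sources[i+1:], sources[i])
    if len(sources) >= 2+minremsources:
        for e1, rs1, v1 in expressions(sources, ops, minremsources+1):
            for e2, rs2, v2 in expressions(rs1, ops, minremsources):
                for o in ops:
                    try: yield ([o, e1, e2], rs2, o(v1, v2))
                    except ValueError: pass

def findfirsttarget(target, sources, ops=standard_ops):
    for e, s, v in expressions(sources, ops):
        if v == target:
            return e
    return []

def evaluate(e):
    """Hypothetical helper: recompute the value of a nested [op, lhs, rhs] list."""
    if isinstance(e, list):
        o, e1, e2 = e
        return o(evaluate(e1), evaluate(e2))
    return e
```

For example, `evaluate(findfirsttarget(14, [2, 3, 4]))` recomputes the value of the first expression found for target 14. Because both orderings of every pair of sub-expressions are enumerated, the `x <= y` guards only discard mirrored duplicates of commutative operations.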