Cythonize!

In this post I take Python code to do a task (draw fractals), write it in the spirit of Python (a pleasure to write, read and maintain), use profiling to figure out the bottlenecks, and then speed up the most time-consuming function using Cython.

Python is great for writing projects of all sizes. Python code is easy to write, read and maintain (when well written) and there are many packages people have written to cover almost any use case you might have. Python, however, is slow. You don’t really feel this when you prototype things, but when you throw real data at the Python code, you wonder if things could be faster.

When I first took up Python I learned about Boost.Python, SWIG and similar ways of wrapping C code to call from Python. The boilerplate involved dampened my enthusiasm. Such boilerplate (and the slow compile/debug cycle) was the major reason why I moved away from C/C++, which I otherwise quite liked.

Then I recently discovered Cython. Cython is pretty interesting. It is a superset of the Python language with type declaration directives. It turns out that dynamic typing is a major reason why Python is “slow”. Cython lets us statically type variables of our choice, thereby getting rid of some of the overhead.
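To get a feel for what dynamic typing costs, here is a small sketch using the standard dis module. The function name add is just an illustration. The point is that the interpreter compiles a + b to one generic instruction and has to work out the operand types every time it runs; the exact opcode name printed depends on your Python version.

```python
import dis


def add(a, b):
    # Nothing here says what a and b are; that is decided at run time.
    return a + b


# The same compiled function handles ints, floats, complex numbers and strings.
results = [add(1, 2), add(1.5, 2.5), add(1 + 2j, 3), add("a", "b")]
print(results)

# List the bytecode instruction names for add. There is a single generic
# binary-operation instruction, whatever types are passed in.
opnames = [instruction.opname for instruction in dis.get_instructions(add)]
print(opnames)
```

A static type declaration lets a compiler skip that run-time lookup, which is the overhead Cython can remove for the variables we choose to type.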

Consider a program to compute a representation of the Mandelbrot fractal:

def f(x, y, max_iter=1000):
  """
  z <- z^2 + c
  """
  c = x + y*1j
  z = (0 +0j)
  for n in range(max_iter):
    z = z**2 + c
    if abs(z) > 2:
      return 1.0 - float(n)/max_iter
  else:
    return 0.0

def mandelbrot(x0=-2.0, y0=-1.5, x1=1.0, y1=1.5, grid=50, max_iter=1000):
  dx, dy = float(x1 - x0)/grid, float(y1 - y0)/grid
  return [[f(x0 + n * dx, y0 + m * dy, max_iter) for n in range(grid)] for m in range(grid)]

This can be called and plotted as follows:

import pylab
pylab.imshow(mandelbrot(grid=500), cmap=pylab.cm.gray)

Run the code a few times and you will find that things get slow very quickly: the number of grid points, and so the number of calls to f, grows as the square of the grid size, and each call may run up to max_iter iterations.

Now, since this is a reduced example, we can tell from simple inspection that the bottleneck function here is f. In a real-world case you would make sure where the bottleneck really is, perhaps using Python’s useful cProfile module.
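To make the scaling concrete, here is a small sketch that counts calls to f instead of timing them. The helper count_calls and the small max_iter are illustrative choices of mine, not part of the original script.

```python
def f(x, y, max_iter=1000):
    """Escape-time value for c = x + iy under z <- z^2 + c."""
    c = x + y * 1j
    z = 0 + 0j
    for n in range(max_iter):
        z = z ** 2 + c
        if abs(z) > 2:
            return 1.0 - float(n) / max_iter
    return 0.0


def count_calls(grid, max_iter=20):
    """Hypothetical helper: walk the grid and report how often f was called."""
    x0, y0, x1, y1 = -2.0, -1.5, 1.0, 1.5
    dx, dy = (x1 - x0) / grid, (y1 - y0) / grid
    calls = 0
    for m in range(grid):
        for n in range(grid):
            f(x0 + n * dx, y0 + m * dy, max_iter)
            calls += 1
    return calls


# Doubling the grid size quadruples the number of calls to f.
counts = {grid: count_calls(grid) for grid in (10, 20, 40)}
print(counts)
```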

An easy way to do that is to import the code we just wrote and then run it using the profiler.

import cProfile
cProfile.run('mandelbrot(grid=100)', 'stats')

This profiles the run of the function and saves the statistics to the file ‘stats’. We then use the pstats module to load and print these statistics. We order them by the total time spent inside each function (excluding the calls it makes to other functions), since we are trying to figure out where the computer spends most of its effort.

import pstats
p = pstats.Stats('stats')
p.sort_stats('tottime').print_stats(10)

Leading to an analysis that looks a bit like:

In [87]: run mandelbrot.py
Sat Jul 19 19:41:57 2014    stats

         1757763 function calls in 0.755 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10000    0.557    0.000    0.748    0.000 .../mandelbrot.py:2(f)
  1737658    0.138    0.000    0.138    0.000 {abs}
    10101    0.053    0.000    0.053    0.000 {range}
        1    0.007    0.007    0.755    0.755 .../mandelbrot.py:15(mandelbrot)
        1    0.000    0.000    0.000    0.000 {numpy.core.multiarray.zeros}
        1    0.000    0.000    0.755    0.755 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
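If you would rather not leave a stats file lying around, the same kind of report can be collected in memory. This is a minimal sketch using cProfile.Profile and an io.StringIO buffer; the smaller grid and max_iter defaults are my own choice to keep it quick.

```python
import cProfile
import io
import pstats


def f(x, y, max_iter=200):
    """Escape-time value for c = x + iy under z <- z^2 + c."""
    c = x + y * 1j
    z = 0 + 0j
    for n in range(max_iter):
        z = z ** 2 + c
        if abs(z) > 2:
            return 1.0 - float(n) / max_iter
    return 0.0


def mandelbrot(x0=-2.0, y0=-1.5, x1=1.0, y1=1.5, grid=20, max_iter=200):
    dx, dy = float(x1 - x0) / grid, float(y1 - y0) / grid
    return [[f(x0 + n * dx, y0 + m * dy, max_iter) for n in range(grid)]
            for m in range(grid)]


# Profile only the call we care about.
profiler = cProfile.Profile()
profiler.enable()
mandelbrot()
profiler.disable()

# Send the report to a string buffer instead of a file on disk.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("tottime").print_stats(10)
report = buffer.getvalue()
print(report)
```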

As we guessed, the inner loop function f is the main CPU hog. In the past I would have stopped there. Sure, I could rewrite the inner loop in C, which would be satisfactorily efficient, but oh, the headaches of making it work with Python! Cython lets us get some speedup with a minimum of fuss.

For a full introduction to Cython, you should follow the tutorial, but here is one approach to using Cython to speed up the code we just wrote.

First, we simply transfer the bottleneck function to its own Cython module, which is indicated by the extension .pyx. Since Cython is a superset of Python, for this first pass we simply leave f as it is, creating a file

mandel_f.pyx
def f(x, y, max_iter=1000):
  """
  z <- z^2 + c
  """
  c = x + y*1j
  z = (0 +0j)
  for n in range(max_iter):
    z = z**2 + c
    if abs(z) > 2:
      return 1.0 - float(n)/max_iter
  else:
    return 0.0

We then modify our original script to use the cythonized function:

import pyximport
pyximport.install(pyimport = True)
from mandel_f import f

def mandelbrot(x0=-2.0, y0=-1.5, x1=1.0, y1=1.5, grid=50, max_iter=1000):
  dx, dy = float(x1 - x0)/grid, float(y1 - y0)/grid
  return [[f(x0 + n * dx, y0 + m * dy, max_iter) for n in range(grid)] for m in range(grid)]

There are several things happening here at once, and there are other ways to do this. Whatever code you put in the .pyx file is taken to be Cython code. Cython is a superset of Python, so this time we left the Python code as is and Cython does just fine with it. Cython first translates the .pyx file into C, then compiles that C into an extension module with the appropriate hooks to Python, and finally we import the compiled module into our code.

We could do this manually, but if our code is simple we can use the pyximport module.

Executing pyximport.install(pyimport = True) installs an import hook that compiles Cython files automatically when they are imported, hiding the whole process from us. Therefore, the statement from mandel_f import f causes mandel_f.pyx to be compiled and imported. (The pyimport = True argument additionally asks pyximport to try compiling ordinary .py modules; a plain pyximport.install() is enough for .pyx files.)
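If the script has to run on machines without Cython or a C compiler, one option is to fall back to the pure Python function. This is only a sketch under my own assumptions: it uses a plain pyximport.install(), which should be enough for .pyx files, it assumes a failed or missing build surfaces as an ImportError, and the USING_CYTHON flag is a made-up name.

```python
try:
    import pyximport  # ships with Cython
    pyximport.install()  # register the import hook for .pyx files
    from mandel_f import f  # built on first import if mandel_f.pyx is found
    USING_CYTHON = True  # hypothetical flag, handy when reading profiles
except ImportError:
    USING_CYTHON = False

    def f(x, y, max_iter=1000):
        """Pure Python fallback: z <- z^2 + c."""
        c = x + y * 1j
        z = 0 + 0j
        for n in range(max_iter):
            z = z ** 2 + c
            if abs(z) > 2:
                return 1.0 - float(n) / max_iter
        return 0.0


print("using Cython:", USING_CYTHON)
print(f(2.0, 2.0, 10), f(0.0, 0.0, 10))
```

Either way the rest of the script calls f the same way, so the mandelbrot function does not need to change.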

Profiling this code yields

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10000    0.348    0.000    0.348    0.000 {mandel_f.f}
    ...

This is a small speedup (0.348 s against the 0.557 s that f took in pure Python), and we got it without changing the function at all.

Suppose we start to put some type information into the code:

# double matches the precision of Python's float; a C float would be single precision
def f(double x, double y, int max_iter=1000):
  """
  z <- z^2 + c
  """
  c = x + y*1j
  z = (0 +0j)
  cdef int n
  for n in range(max_iter):
    z = z**2 + c
    if abs(z) > 2:
      return 1.0 - float(n)/max_iter
  else:
    return 0.0

The function signature now resembles a C signature, and in the body we have declared n to be an int. Profiling yields

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10000    0.243    0.000    0.243    0.000 {mandel_f2.f}

That is a modest speedup of roughly 2x over the original pure Python code (0.243 s against 0.557 s).

Now we will do something that is a bit messy. See how we left the complex numbers alone? We could try and cythonize that too.

cdef extern from "complexobject.h":

  struct Py_complex:
    double real
    double imag

  ctypedef class __builtin__.complex [object PyComplexObject]:
    cdef Py_complex cval

def f(double x, double y, int max_iter=1000):  # double, to keep Python float precision
  """
  z <- z^2 + c
  """
  cdef complex c = x + y*1j
  cdef complex z = (0 +0j)
  cdef int n
  for n in range(max_iter):
    z = z**2 + c
    if abs(z) > 2:
      return 1.0 - float(n)/max_iter
  else:
    return 0.0

Profiling this we get

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10000    0.186    0.000    0.186    0.000 {mandel_f3.f}

which is around a 3x speedup over the pure Python version (0.186 s against 0.557 s).

A proper comparison, of course, would be to rewrite the function in C, wrap it, and profile the wrapped version.
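The speedup figures are easy to recompute from the profiles. The sketch below uses only timings printed in this post. It shows two baselines because the compiled f reports identical tottime and cumtime, which suggests that the abs and range calls inside it are no longer listed separately; if that is so, the pure Python cumtime may be the fairer comparison. The labels are my own.

```python
# Timings in seconds, copied from the profiles in this post.
pure_tottime = 0.557  # time inside f itself, pure Python
pure_cumtime = 0.748  # f plus the abs and range calls it makes

cython_runs = [
    ("untyped .pyx", 0.348),
    ("typed arguments and loop index", 0.243),
    ("typed complex numbers", 0.186),
]

speedups = {}
for label, seconds in cython_runs:
    # Speedup = pure Python time divided by Cython time, for each baseline.
    speedups[label] = (round(pure_tottime / seconds, 1),
                       round(pure_cumtime / seconds, 1))
    print("%-32s %.1fx vs tottime, %.1fx vs cumtime"
          % ((label,) + speedups[label]))
```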

I would like to do one more thing before we leave. Function calls in Python have a lot of overhead, and this function is called once per grid point. Let’s see how much we gain by “vectorizing” our function, that is, passing in whole lists of coordinates and looping inside it. This keeps the function neat while replacing the many calls with a single one.

def f(x_list, y_list, max_iter=1000):
  """
  z <- z^2 + c
  """
  rows = []
  for y in y_list:
    row = []  # one row of the image: fixed y, varying x
    for x in x_list:
      c = x + y*1j
      z = (0 +0j)
      for n in range(max_iter):
        z = z**2 + c
        if abs(z) > 2:
          rv = 1.0 - float(n)/max_iter
          break
      else:
        rv = 0.0
      row.append(rv)
    rows.append(row)
  return rows

def mandelbrot(x0=-2.0, y0=-1.5, x1=1.0, y1=1.5, grid=50, max_iter=1000):
  dx, dy = float(x1 - x0)/grid, float(y1 - y0)/grid
  return f([x0 + n * dx for n in range(grid)] , [y0 + m * dy for m in range(grid)] , max_iter)

def quick_plot(grid=100):
  import pylab
  pylab.imshow(mandelbrot(grid=grid), cmap=pylab.cm.gray)
  pylab.show()

Profiling the pure Python version of this gives

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.501    0.501    0.670    0.670 .../mandelbrot_vec.py:1(f)

After applying all the cythonizations we learned:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.188    0.188    0.188    0.188 {mandel_f_vec.f}

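Not THAT impressive, huh? I was pretty surprised, as I had hoped the “vectorized” form would speed things up a bunch by getting rid of function call overhead. In hindsight the profiles explain it: with grid=100 there are only 10,000 calls to f against roughly 1.7 million iterations of the inner loop (the abs count in the first profile), so call overhead was never a large share of the run time, and the Cython timings are essentially unchanged (0.188 s against 0.186 s).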
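If you want to try this comparison yourself in pure Python, the sketch below checks that the per-point and batched forms produce the same grid and then times both with timeit. The function names, the small grid and the repeat count are illustrative choices of mine, and the timings will vary from machine to machine.

```python
import timeit


def f_point(x, y, max_iter):
    """One call per grid point, as in the first version."""
    c = x + y * 1j
    z = 0 + 0j
    for n in range(max_iter):
        z = z ** 2 + c
        if abs(z) > 2:
            return 1.0 - float(n) / max_iter
    return 0.0


def f_grid(x_list, y_list, max_iter):
    """One call for the whole grid, with the loops moved inside."""
    rows = []
    for y in y_list:
        row = []
        for x in x_list:
            c = x + y * 1j
            z = 0 + 0j
            for n in range(max_iter):
                z = z ** 2 + c
                if abs(z) > 2:
                    rv = 1.0 - float(n) / max_iter
                    break
            else:
                rv = 0.0
            row.append(rv)
        rows.append(row)
    return rows


def mandelbrot_per_point(grid, max_iter, x0=-2.0, y0=-1.5, x1=1.0, y1=1.5):
    dx, dy = (x1 - x0) / grid, (y1 - y0) / grid
    return [[f_point(x0 + n * dx, y0 + m * dy, max_iter) for n in range(grid)]
            for m in range(grid)]


def mandelbrot_batched(grid, max_iter, x0=-2.0, y0=-1.5, x1=1.0, y1=1.5):
    dx, dy = (x1 - x0) / grid, (y1 - y0) / grid
    return f_grid([x0 + n * dx for n in range(grid)],
                  [y0 + m * dy for m in range(grid)], max_iter)


# Both forms should produce exactly the same values.
same = mandelbrot_per_point(20, 100) == mandelbrot_batched(20, 100)

# Time each form a few times; only the number of function calls differs.
t_point = timeit.timeit(lambda: mandelbrot_per_point(20, 100), number=3)
t_batch = timeit.timeit(lambda: mandelbrot_batched(20, 100), number=3)
print("same result:", same)
print("per-point: %.4f s, batched: %.4f s" % (t_point, t_batch))
```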
