It has been a while since I last looked into the development of Python compilers that improve the running time of critical paths in applications, e.g. in quantitative finance or scientific research. There are now many more libraries tackling the problem, either by bridging to CPython C extensions or by just-in-time (JIT) compilation. I will first go through three libraries that are prevalent nowadays.
Cython, Numba and JAX
First came Cython, the backbone of Pandas, which gives easy access to C extensions with a Python-like syntax. Then Numba took advantage of modern LLVM development to compile a subset of Python and numpy code on the fly, starting from the CPython bytecode. Finally, JAX from Google joined the game in 2018. It first focused on gradients and differentiation for deep learning, and now there are more industrial use cases for its JIT compilation functionality.
Benchmark function
I picked a function that computes the weighted sum of numerical integrals. The integration range is sliced arbitrarily, and the slices come in as an input parameter.
I first wrote the function in plain Python, assuming it is the most straightforward approach that developers and researchers would employ. The numpy sum method actually runs in a C extension underneath. Cheating! But let’s see how it fares in the comparison afterward.
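The notebook code is not reproduced in this post, so here is only a minimal sketch of what such a baseline might look like. The function name weighted_integral_sum, the parameters values, weights and dx, and the rectangle-rule integration are my assumptions for illustration; only the slicing array t_data is named in the post.

```python
import numpy as np


def weighted_integral_sum(values, weights, t_data, dx):
    """Weighted sum of rectangle-rule integrals over slices of `values`.

    Slice i covers values[t_data[i]:t_data[i + 1]] and is scaled by weights[i].
    All names except t_data are hypothetical.
    """
    total = 0.0
    for i in range(len(t_data) - 1):
        start, stop = t_data[i], t_data[i + 1]
        # np.sum runs in compiled C code; the loop around it is plain Python.
        total += weights[i] * np.sum(values[start:stop]) * dx
    return total


values = np.arange(10.0)          # samples of the integrand
t_data = np.array([0, 3, 10])     # slice boundaries supplied by the caller
weights = np.array([1.0, 2.0])    # one weight per slice
result = weighted_integral_sum(values, weights, t_data, dx=0.5)
print(result)
```

The per-slice work is done by numpy, so the Python overhead comes mostly from the loop over slices.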
Cython
For Cython, as in my previous experience, I first had to convert the native Python function to Cython syntax, defining the function as “cpdef” (accessible from both Cython and Python). Second, I declared the types of all the function parameters, variables and outputs. Finally, I compiled and prayed for a great result. In general, I reckoned that I had to abandon all the numpy built-in functions (even one as simple as sum) and rewrite the algorithm (e.g. with a for loop).
The Python version takes about 2 ms per loop, while, after a long endeavour, the Cython version runs at the microsecond level. Not bad.
Numba
Life with Numba is much more enjoyable. I can compile the native Python function directly with its helper function. I didn’t bother much with tuning the parameters in Numba.
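As a rough illustration of how little changes, here is a sketch that passes a baseline function to numba.njit untouched. The function and its parameter names (other than t_data) are hypothetical stand-ins for the notebook code, and I am assuming njit is the helper in question. The fallback for a missing Numba installation is only there so the sketch still runs.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: without Numba, fall back to the plain function.
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


def weighted_integral_sum(values, weights, t_data, dx):
    """Hypothetical baseline: weighted sum of rectangle-rule slice integrals."""
    total = 0.0
    for i in range(len(t_data) - 1):
        total += weights[i] * np.sum(values[t_data[i]:t_data[i + 1]]) * dx
    return total


# No rewrite and no type declarations: the existing function is compiled as is.
weighted_integral_sum_nb = njit(weighted_integral_sum)

values = np.arange(10.0)
t_data = np.array([0, 3, 10])
weights = np.array([1.0, 2.0])
dx = 0.5

# The first call triggers compilation; later calls reuse the compiled code.
print(weighted_integral_sum_nb(values, weights, t_data, dx))
```

When benchmarking, the first call should be left out of the timing because it includes the compilation.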
And the performance is amazing, as good as the Cython implementation.
JAX
Finally, there are a few considerations in running the JAX benchmark. First, the numpy array needs to be converted into a JAX array, and the conversion overhead is excluded from the benchmark result. Second, JAX dispatches operations asynchronously, so the benchmark has to block and wait for the complete result to arrive. Third, the result in the example is a single float but is stored in a JAX array. Though the conversion to a native Python type does not incur much overhead, it is excluded from the benchmark as well.
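The three considerations can be sketched with a toy function. This is not the benchmark from the notebook; sum_of_squares and the array size are placeholders I made up to show where the conversion, the blocking call and the float conversion sit relative to the timed region.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(x):
    # Placeholder workload, not the benchmark function itself.
    return jnp.sum(x * x)


x_np = np.arange(1000, dtype=np.float32)

# 1. Convert numpy -> JAX array before timing starts.
x = jnp.asarray(x_np)

# Warm-up call so that compilation is not part of the measurement.
sum_of_squares(x).block_until_ready()

# 2. Block until the asynchronously dispatched result has actually arrived.
start = time.perf_counter()
result = sum_of_squares(x).block_until_ready()
elapsed = time.perf_counter() - start

# 3. Convert the JAX scalar array to a native Python float after timing stops.
value = float(result)
print(value, elapsed)
```

Without block_until_ready, the timer would mostly measure how long it takes to dispatch the work rather than how long it takes to finish it.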
In my example, the Python implementation slices the array dynamically, i.e. the slicing indices originate from user input and are not determined at compilation time. At the moment, JAX cannot JIT-compile slices whose sizes are only known at run time. As a workaround, a partial function is constructed by binding the slicing array (t_data) before compilation. The compilation takes a while (around 2 mins when the size of the slicing array is 1000), and I suppose the compiled result is large.
The running time improves a lot, though it is not as fast as the Cython and Numba versions.
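A minimal sketch of this workaround, assuming a hypothetical function weighted_integral_sum with parameters values, weights and dx (only t_data is named in the post). Because t_data is bound with functools.partial, it stays an ordinary numpy array during tracing and every slice boundary is a plain Python integer.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


def weighted_integral_sum(t_data, values, weights, dx):
    """Hypothetical JAX version; t_data must be concrete (not traced)."""
    total = 0.0
    for i in range(len(t_data) - 1):
        # Python ints make each slice static, so its size is known when tracing.
        start, stop = int(t_data[i]), int(t_data[i + 1])
        total = total + weights[i] * jnp.sum(values[start:stop]) * dx
    return total


t_data = np.array([0, 3, 10])

# Bind t_data first, then compile only over the remaining arguments.
compiled = jax.jit(partial(weighted_integral_sum, t_data))

values = jnp.arange(10.0)
weights = jnp.array([1.0, 2.0])
result = compiled(values, weights, 0.5).block_until_ready()
print(float(result))
```

The Python loop is unrolled while tracing, so the traced program contains one slice and one sum per entry in t_data. That is consistent with compilation getting slow as the slicing array grows, and a different t_data needs a fresh compilation.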
Which one is the best?
Does this mean that Cython is still the best library for compiling Python code down to the C extension level? I would like to answer the question on a few levels.
1. Runtime
The first level is machine running time. It is actually an illusion that Cython gives the best performance in my example. If I rewrite the native Python function like the Cython implementation (using a for loop to iterate rather than numpy.sum), the Numba-compiled function produces the best performance.
Also, Numba has put a vast amount of effort into supporting numpy features. I believe that in a lot of cases, especially those involving numpy arrays, Numba outperforms Cython by a significant margin.
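The loop-style rewrite could look roughly like the sketch below. The function name and all parameters except t_data are hypothetical, and the rectangle-rule integration is my assumption about the benchmark.

```python
import numpy as np


def weighted_integral_sum_loop(values, weights, t_data, dx):
    """Hypothetical loop-only variant: numpy.sum replaced by an explicit loop."""
    total = 0.0
    for i in range(len(t_data) - 1):
        slice_sum = 0.0
        # Accumulate the slice element by element instead of calling numpy.sum.
        for j in range(t_data[i], t_data[i + 1]):
            slice_sum += values[j]
        total += weights[i] * slice_sum * dx
    return total


values = np.arange(10.0)
t_data = np.array([0, 3, 10])
weights = np.array([1.0, 2.0])
loop_result = weighted_integral_sum_loop(values, weights, t_data, 0.5)
print(loop_result)
```

Run as plain Python, this form is slower than the numpy.sum version, since every element access goes through the interpreter. It is written in this shape for the benefit of a compiler, which can turn the explicit loops into tight machine code.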
2. Development effort
The second level is development time. Numba gives users a great benefit and assurance: they can keep the existing Python code and then attempt to improve its runtime. The investment of development time is minimal in Numba, while the reward can be huge. In this respect, Numba outperforms Cython and JAX.
3. Long term maintenance
The third level is long-term adaptation to revolutionary technology. Numpy and Cython can only support CPU runtime, while Numba and JAX have already adapted to GPUs. JAX even supports TPUs, a critical advantage for machine learning practitioners. The asynchronous dispatch feature in JAX echoes modern machine learning development, where CPU and GPU usage in clusters is optimised.
Reference
The benchmark notebook is accessible in Google Colab. Comments and ideas are welcome!
(The original post was on LinkedIn)