*You can find the code used in this article on Colab.*

Recently, it seems that everyone is talking about this new library: JAX.

But what is all the fuss about? Is it a new state-of-the-art machine learning interface? A new way to scale up your applications? From the official documentation:

> JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Maybe the best way to understand what it does is to list what one looks for when choosing a new tool for **prototyping new ideas**:

- **Easy to use**: hop right in and write some code in a matter of seconds;
- **Fast**: optimized and running quickly so you can test things easily;
- **Powerful**: provides some awesome features that are not found elsewhere;
- **Understandable**: the syntax is immediately clear and we know what our code is doing.

**JAX** meets all of these criteria. It’s a scientific Python library that is more like a toolbox for prototyping ideas. It:

- Implements a NumPy API so it’s easy to use if you already know NumPy;
- Allows for JIT compilation of your code and automatic vectorization and parallelization of functions for batch computing, and it runs on CPUs, GPUs, and TPUs transparently, so it’s fast;
- Provides an autograd feature to compute the gradient of pretty much any function.

## JAX as a NumPy API

Let’s say you have been working on some code using NumPy as its scientific library backend, and you just want to differentiate this code, or you want it to run on the GPU.

```python
import jax as jx

import numpy as np       # Don't use this
import jax.numpy as jnp  # Cool kids do this!
```

It’s as simple as that!

To be fair, your code has to respect a few constraints for it to work.

- You can only use pure functions: calling your function twice with the same inputs has to return the same result, and you can’t do in-place updates of arrays.

```python
array = jnp.zeros((2, 2))

array[:, 0] = 1  # Won't work: JAX arrays are immutable

array = array.at[:, 0].set(1)  # Better: a functional update that returns a new array
```
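Purity matters because JAX traces functions and caches the result (that is what powers the JIT compilation covered below): side effects only run during tracing, not on every call. A minimal sketch of this gotcha, where `impure_fn` is just an illustrative name:

```python
import jax as jx
import jax.numpy as jnp

@jx.jit
def impure_fn(x):
    # A side effect: this print only happens while JAX traces the function
    print("tracing!")
    return x + 1

out1 = impure_fn(jnp.ones(3))  # prints "tracing!" on the first call
out2 = impure_fn(jnp.ones(3))  # cached version runs: nothing is printed
```

The numerical result is still correct both times; only the side effect silently disappears, which is why JAX asks for pure functions.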

- Random number generation is explicit and uses a PRNG key:

```python
key = jx.random.PRNGKey(1)

x = jx.random.normal(key=key, shape=(1000,))
```
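Because a key always produces the same stream, drawing independent samples means splitting the key first with `jax.random.split`. A minimal sketch:

```python
import jax as jx

key = jx.random.PRNGKey(1)

# Same key, same numbers: sampling is fully reproducible
a = jx.random.normal(key=key, shape=(3,))
b = jx.random.normal(key=key, shape=(3,))  # identical to a

# Split the key to get fresh, independent randomness
key, subkey = jx.random.split(key)
c = jx.random.normal(key=subkey, shape=(3,))  # different from a
```

This explicitness is what keeps random number generation compatible with pure functions: the "state" of the generator is just a value you pass around.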

But apart from that, you can still use it to do the same operations as in NumPy.

```python
def numpy_softmax(x: np.ndarray) -> np.ndarray:
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum(axis=0)


def jax_softmax(x: jnp.ndarray) -> jnp.ndarray:
    exp_x = jnp.exp(x - jnp.max(x))
    return exp_x / exp_x.sum(axis=0)
```
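As a quick sanity check, the two versions agree on the same input (the definitions are repeated here so the snippet stands on its own):

```python
import numpy as np
import jax.numpy as jnp

def numpy_softmax(x):
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum(axis=0)

def jax_softmax(x):
    exp_x = jnp.exp(x - jnp.max(x))
    return exp_x / exp_x.sum(axis=0)

x = np.array([1.0, 2.0, 3.0])
print(np.allclose(numpy_softmax(x), jax_softmax(x), atol=1e-6))  # True
```

Note that `jax_softmax` returns a JAX array rather than a NumPy one, but the two interoperate transparently in calls like `np.allclose`.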

## Make it go faster

One of the cool features of JAX is its ability to compile your Python code with XLA. It’s called Just-In-Time (JIT) compilation: the first time you call a jitted function, JAX compiles it and caches the compiled version, so subsequent calls run much faster.

It’s easy to use: just decorate your function with **jax.jit** or call **jax.jit** on your function. We use **timeit** to time the execution.

```python
%timeit jax_softmax(x)
%timeit jx.jit(jax_softmax)(x)
%timeit numpy_softmax(x)
```

```
1000 loops, best of 5: 1.15 ms per loop
1000 loops, best of 5: 296 µs per loop
1000 loops, best of 5: 646 µs per loop
```

And voilà!
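The decorator form mentioned above is the more common way to use it in practice: the function is compiled on its first call and every later call reuses the cache. A short sketch, reusing the softmax from earlier:

```python
import jax as jx
import jax.numpy as jnp

@jx.jit
def jitted_softmax(x):
    exp_x = jnp.exp(x - jnp.max(x))
    return exp_x / exp_x.sum(axis=0)

key = jx.random.PRNGKey(0)
x = jx.random.normal(key=key, shape=(1000,))

probs = jitted_softmax(x)  # first call compiles; later calls hit the cache
```

One thing to keep in mind: JAX compiles a separate version per input shape and dtype, so calling a jitted function with many different shapes triggers repeated recompilation.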

## Automatic differentiation

With machine learning or optimization in mind, you might want to quickly compute the gradient of pretty much any function. Lucky you: that’s exactly what JAX offers! Use the **jax.grad** operator; you can even apply it multiple times to the same function to get higher-order derivatives.

```python
import matplotlib.pyplot as plt

cos_grad = jx.grad(jnp.cos)
cos_grad_grad = jx.grad(cos_grad)
cos_grad_grad_grad = jx.grad(cos_grad_grad)

x = jnp.linspace(-jnp.pi, jnp.pi, 1000)

plt.plot(x, jnp.cos(x))
plt.plot(x, [cos_grad(i) for i in x])
plt.plot(x, [cos_grad_grad(i) for i in x])
plt.plot(x, [cos_grad_grad_grad(i) for i in x])
```
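Two more tricks worth knowing: **jax.value_and_grad** returns the function value and its gradient in one pass, and the `argnums` parameter of **jax.grad** selects which argument to differentiate with respect to. A short sketch on a toy function `f` (the name and numbers are just for illustration):

```python
import jax as jx
import jax.numpy as jnp

def f(w, x):
    # A toy scalar function: f(w, x) = w * sum(x^2)
    return jnp.sum(w * x ** 2)

w = 2.0
x = jnp.array([1.0, 2.0])

# Value and gradient w.r.t. the first argument (the default) in one call
val, dw = jx.value_and_grad(f)(w, x)
# val = 2 * (1 + 4) = 10.0, dw = sum(x^2) = 5.0

# argnums=1 differentiates w.r.t. the second argument instead
dx = jx.grad(f, argnums=1)(w, x)
# dx = 2 * w * x = [4.0, 8.0]
```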

## Vectorization

If you ran the previous code in the Colab notebook, you may have found that it took an awful lot of time to execute. So why not use one more feature of JAX to make it faster? **jax.vmap** allows you to vectorize any function, which is useful when you want to batch operations, for instance.

```python
cos_grad_vec = jx.vmap(cos_grad)

%timeit [cos_grad(i) for i in x]
%timeit cos_grad_vec(x)
```

```
1 loop, best of 3: 1.66 s per loop
100 loops, best of 3: 2.41 ms per loop
```

It gets even better: you can specify which argument of the function the vectorization should happen over, so batching is as simple as it gets.

```python
def loss(x, constant):
    return jnp.dot(x.T, x) + constant


batched_loss = jx.vmap(loss, in_axes=(0, None), out_axes=0)
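With `in_axes=(0, None)`, the first argument is mapped over its leading axis while the constant is broadcast to every element of the batch. A quick sketch of what that looks like in use (the batch values here are just for illustration):

```python
import jax as jx
import jax.numpy as jnp

def loss(x, constant):
    return jnp.dot(x.T, x) + constant

batched_loss = jx.vmap(loss, in_axes=(0, None), out_axes=0)

batch = jnp.ones((4, 3))         # a batch of 4 vectors of size 3
out = batched_loss(batch, 10.0)  # one scalar loss per vector

print(out.shape)  # (4,)
# each entry is ||x||^2 + constant = 3 + 10 = 13
```

`out_axes=0` says the results should be stacked along a new leading axis, giving one loss value per batch element.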

## JAX: Where to go next?

We saw that JAX has many powerful features that make it a go-to tool for writing scientific Python that runs fast. It’s super flexible, fast, and easy to learn.

What JAX lacked for a long time was a rich ecosystem. With recent projects starting to mature, we are seeing more and more libraries and projects centered around JAX, such as:

- Haiku: a simple functional neural network library for JAX;
- Flax: an object-oriented neural network library and ecosystem for JAX designed for flexibility;
- Trax: an end-to-end deep learning library that focuses on clear code and speed;
- Optax: a library for optimization and gradient transformations.

Although it was released in 2018, JAX really gained traction in 2020, and its use is spreading fast. Google has been advocating for it for a while, and more and more projects are powered by it, such as the implementation of the **An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale** paper. So now is a good time to try it for yourself!