JAX (often expanded as "Just After eXecution") is a recent machine learning and deep learning library developed by Google and used extensively at DeepMind.
Unlike TensorFlow, JAX is not an official Google product and is used mainly for research purposes. Its use is growing in the research community thanks to features such as JIT compilation, automatic differentiation, and automatic vectorization. Its NumPy-like syntax also means there is little new syntax to learn.
JAX is basically a Just-In-Time (JIT) compiler focused on harnessing the maximum number of
While JAX requires a course of its own, we will go through some of its features in this shot.
    import jax
    import jax.numpy as jnp
    import numpy as np

    a = np.linspace(0.0, 2.0, 10)
    print(a)
    b = np.zeros((10, 20))
    print(b)

    print("---And now JAX versions-----")
    jnp_a = jnp.linspace(0.0, 2.0, 10)
    jnp_b = jnp.zeros((10, 20))
    print(jnp_a)
    print(jnp_b)
Although JAX has the same syntax as NumPy, it differs in some aspects:
The examples below will further illustrate these differences.
    import jax
    import jax.numpy as jnp
    import numpy as np

    a = np.linspace(0.0, 2.0, 10)
    b = np.zeros((10, 20))
    print(type(a))

    print("---JAX array---")
    jnp_a = jnp.linspace(0.0, 2.0, 10)
    jnp_b = jnp.zeros((10, 20))
    print(type(jnp_a))
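Beyond the array type printed above, one difference that is easy to check is that JAX arrays are immutable. This is a well-known property of JAX rather than something stated in the text above, so treat the following as a minimal sketch. The variable names `in_place_ok` and `jnp_c` are only illustrative.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays can be modified in place.
a = np.zeros(3)
a[0] = 1.0

# JAX arrays reject item assignment.
jnp_a = jnp.zeros(3)
try:
    jnp_a[0] = 1.0
    in_place_ok = True
except TypeError:
    in_place_ok = False

# The functional alternative returns a new array and leaves jnp_a untouched.
jnp_c = jnp_a.at[0].set(1.0)

print(in_place_ok)
print(jnp_a)
print(jnp_c)
```

The `.at[...].set(...)` form keeps updates purely functional, which is what lets JAX trace and transform code safely.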
JAX allows us to perform JIT compilation. All you have to do is wrap the function with jit() or decorate it with @jit, as shown below.
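    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import jit

    # Plain Python function
    def Square(x):
        return x * x

    # Same function, compiled with the @jit decorator
    @jit
    def JSquare(x):
        return x * x

    print(Square(4.1))
    print(JSquare(4.1))
    # Wrapping with jit() is equivalent to using the decorator
    print(jit(Square)(4.1))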
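The output is numerically the same (up to floating-point precision). The difference appears once the function contains a selection structure such as an if statement. Under jit, JAX traces the function with abstract tracer objects that stand in for the actual values, so a Python condition that depends on an argument's value cannot be evaluated and raises an error.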
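To make the point about selection structures concrete, here is a minimal sketch. It assumes a recent JAX release in which this failure is reported as a subclass of `jax.errors.ConcretizationTypeError`. The function names `relu_python_if` and `relu_where` are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def relu_python_if(x):
    # A Python `if` needs a concrete boolean, but under jit `x` is a tracer.
    if x > 0:
        return x
    return 0.0

@jit
def relu_where(x):
    # jnp.where expresses the selection inside the traced computation.
    return jnp.where(x > 0, x, 0.0)

try:
    relu_python_if(2.0)
    branch_failed = False
except jax.errors.ConcretizationTypeError:
    branch_failed = True

print(branch_failed)
print(relu_where(2.0))
print(relu_where(-3.0))
```

Replacing the Python `if` with `jnp.where` is one common workaround. JAX also provides `jax.lax.cond` for cases where only one branch should be executed.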
JAX supports both automatic differentiation (Autograd-style, through grad) and automatic vectorization (through vmap). The example below gives a glimpse of automatic differentiation.
    from jax import grad

    def Y(a):
        return 3 * a * a - a + 1

    dy = grad(Y)                 # first derivative: 6a - 1
    dy2 = grad(grad(Y))          # second derivative: 6
    dy3 = grad(grad(grad(Y)))    # third derivative: 0

    a = 2.0
    print(dy(a))
    print(dy2(a))
    print(dy3(a))
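The text mentions autovectorization without showing it, so here is a small sketch that reuses the same function Y. It assumes `jax.vmap` as the vectorization entry point. The names `dY_batched`, `xs`, and `slopes` are illustrative only.

```python
import jax.numpy as jnp
from jax import grad, vmap

def Y(a):
    return 3 * a * a - a + 1

# grad(Y) works on a single scalar; vmap maps it over a batch of inputs
# without an explicit Python loop.
dY_batched = vmap(grad(Y))

xs = jnp.array([0.0, 1.0, 2.0])
slopes = dY_batched(xs)  # derivative 6a - 1 evaluated at each element
print(slopes)
```

Composing `vmap` with `grad` like this is typical of JAX: each transformation takes a function and returns a new function, so they can be stacked freely.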
For more details, please check the relevant course.