What is JAX?

Overview

JAX is a relatively new machine learning library from Google Research. Its name reflects its two main ingredients: Autograd (automatic differentiation) and XLA (Accelerated Linear Algebra).

Unlike TensorFlow, JAX is not an official Google product; it is maintained as a research project. Even so, its use is growing quickly in the research community thanks to a compelling set of features, and its NumPy-like API means there is very little new syntax to learn.

Features of JAX

At its core, JAX is a just-in-time (JIT) compiler that aims to extract the maximum FLOPS (floating-point operations per second) from your hardware while letting you write plain Python. Some of its salient features are:

  • Just-in-Time (JIT) compilation.
  • Runs NumPy code not only on the CPU but also on GPUs and TPUs.
  • Automatic differentiation of NumPy and native Python code.
  • Automatic vectorization.
  • Express and compose transformations of numerical programs.
  • Advanced (pseudo) random number generation.
  • More options for control flow.

While JAX deserves a course of its own, this shot walks through some of its features. The example below builds the same arrays with NumPy and with jax.numpy:

import jax
import jax.numpy as jnp
import numpy as np

# Standard NumPy arrays
a = np.linspace(0.0, 2.0, 10)
print(a)
b = np.zeros((10, 20))
print(b)

print("---And now JAX versions-----")

# The jax.numpy equivalents use the same function names and arguments
jnp_a = jnp.linspace(0.0, 2.0, 10)
jnp_b = jnp.zeros((10, 20))
print(jnp_a)
print(jnp_b)

Although JAX mirrors NumPy's syntax, it differs in a few important ways:

  • GPU/TPU support (hence the warning JAX prints when it falls back to the CPU because no accelerator is found).
  • Different default data types (e.g., 32-bit floats instead of NumPy's 64-bit).
  • Some restrictions (e.g., JAX arrays are immutable).

The examples below will further illustrate these differences.
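
First, to see where the warning about falling back to the CPU comes from, you can ask JAX which devices it found. A minimal sketch (the exact output varies by machine):

import jax

# Lists the devices JAX discovered; on a CPU-only machine this prints
# something like [CpuDevice(id=0)], and JAX warns that it is falling
# back to the CPU because no GPU/TPU was found.
print(jax.devices())

# "cpu", "gpu", or "tpu", depending on the available hardware
print(jax.default_backend())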

import jax
import jax.numpy as jnp
import numpy as np

a = np.linspace(0.0, 2.0, 10)
b = np.zeros((10, 20))
print(type(a))  # a regular numpy.ndarray

print("---JAX array---")

jnp_a = jnp.linspace(0.0, 2.0, 10)
jnp_b = jnp.zeros((10, 20))
print(type(jnp_a))  # a JAX device array, not a numpy.ndarray
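
The other two differences are easy to check directly. A short sketch of the default data types and of immutability (the .at[...] syntax is JAX's functional replacement for in-place assignment):

import jax.numpy as jnp
import numpy as np

# Default dtypes differ: NumPy uses float64, JAX defaults to float32
print(np.linspace(0.0, 2.0, 10).dtype)   # float64
print(jnp.linspace(0.0, 2.0, 10).dtype)  # float32

x = jnp.zeros(5)
# x[0] = 1.0 would raise a TypeError: JAX arrays are immutable.
# Instead, .at[...].set(...) returns a new array with the update applied.
y = x.at[0].set(1.0)
print(y)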

JIT compilation

JAX lets us JIT-compile a function: simply wrap the function in jit() or decorate it with @jit, as shown below.

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

# A plain Python function
def Square(x):
    return x * x

# The same function, compiled with XLA on its first call
@jit
def JSquare(x):
    return x * x

print(Square(4.1))
print(JSquare(4.1))
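
For a tiny scalar function like this, the two calls return the same value, and compilation overhead can even make the jitted version slower. A rough sketch of how you might measure the difference on a larger array (timings are illustrative and vary by machine):

import timeit
import jax.numpy as jnp
from jax import jit

def f(x):
    return jnp.sum(x * x + 2.0 * x)

f_jit = jit(f)
x = jnp.ones((1000, 1000))

# Call once first so compilation happens outside the timed region, and
# use block_until_ready() so async dispatch doesn't skew the timings.
f_jit(x).block_until_ready()

print("eager:", timeit.timeit(lambda: f(x).block_until_ready(), number=100))
print("jit:  ", timeit.timeit(lambda: f_jit(x).block_until_ready(), number=100))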

The output is the same, but a difference appears as soon as the function body contains value-dependent control flow (an if statement, for example): during tracing, jit replaces concrete values with abstract tracer objects, so branching on an argument's value fails.
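
A minimal sketch of that pitfall (the function names are just for illustration):

import jax.numpy as jnp
from jax import jit

@jit
def bad_abs(x):
    # Under jit, x is an abstract tracer, so this Python-level branch
    # on its value raises a tracer/concretization error when traced.
    return x if x >= 0 else -x

@jit
def good_abs(x):
    # jnp.where expresses the branch as data flow, so it traces fine.
    return jnp.where(x >= 0, x, -x)

print(good_abs(-4.1))   # 4.1
# print(bad_abs(-4.1))  # uncomment to see the tracing error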

Autograd

JAX supports both automatic differentiation (Autograd) and automatic vectorization. The example below gives a glimpse of Autograd: grad(Y) returns a new function that computes the derivative of Y, and grad calls can be nested to get higher-order derivatives.

from jax import grad

# y = 3a^2 - a + 1
def Y(a):
    return 3 * a * a - a + 1

dy = grad(Y)               # dy/da   = 6a - 1
dy2 = grad(grad(Y))        # d2y/da2 = 6
dy3 = grad(grad(grad(Y)))  # d3y/da3 = 0

a = 2.0
print(dy(a))   # 11.0
print(dy2(a))  # 6.0
print(dy3(a))  # 0.0
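
Autovectorization is exposed through vmap, which turns a function written for a single example into one that works on a batch. A minimal sketch (the function and array names are illustrative):

import jax.numpy as jnp
from jax import vmap

def dot(v, w):
    # Dot product of two 1-D vectors
    return jnp.dot(v, w)

# vmap maps dot over the leading axis of both arguments, giving a
# batched dot product without an explicit Python loop.
batched_dot = vmap(dot)

vs = jnp.arange(12.0).reshape(4, 3)
ws = jnp.ones((4, 3))
print(batched_dot(vs, ws))  # four dot products, one per row

Because JAX transformations compose, jit(vmap(dot)) or grad of a jitted function work just as well.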

For more details, please check the relevant course.
