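The vectorization map (vmap) is a function in the JAX library that maps a function over one or more axes of its input arguments. It is a powerful tool for speeding up code that applies the same function to many inputs, such as every example in a batch. Instead of calling the function once per input in a Python loop, vmap produces a single vectorized function that processes the whole batch at once.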
The syntax of the vmap function is as follows:

jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)
The vmap function has three main parameters:

fun is the function to be mapped. It is the only required argument.
in_axes specifies which axis of each input argument should be mapped over. It can be an integer, None (to pass an argument unmapped), or a tuple with one entry per positional argument. The default is 0.
out_axes indicates where the mapped axis should appear in the function's output. The default is 0.

The remaining parameters are optional and less commonly used:

axis_name is a name for the mapped axis, which collective operations such as jax.lax.psum can use to refer to it.
axis_size is an integer giving the size of the mapped axis. It is only needed when the size cannot be inferred from the inputs.
spmd_axis_name is an optional name of a device mesh axis. When the batched computation is partitioned across devices with single program, multiple data (SPMD) parallelism, the mapped axis is sharded along that mesh axis.

The vmap function returns a batched version of fun, which applies the original function across the mapped axis of its inputs in a vectorized manner.
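To make axis_name more concrete, here is a minimal sketch (an illustration added here, not part of the original examples) in which a collective operation, jax.lax.psum, refers to the mapped axis by name. The function normalize and the axis name "batch" are hypothetical choices; each element of the batch is divided by the sum over the whole batch.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # Inside vmap, x is a single element of the batch; psum sums x
    # across the mapped axis identified by the name "batch".
    total = jax.lax.psum(x, axis_name="batch")
    return x / total

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# Name the mapped axis so that psum inside normalize can refer to it.
normalized = jax.vmap(normalize, axis_name="batch")(xs)
print(normalized)  # each element divided by the batch total, 10.0
```

Without axis_name, the call to psum inside normalize would have no mapped axis to sum over and would fail.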
Using vmap
To use the vmap function, we first define the function we want to vectorize. Then, we call vmap with that function as the first argument and, optionally, the axes we want to vectorize over as the second argument (in_axes, which defaults to 0).

Consider the following example:
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

# Vectorize add over the first axis of both inputs (in_axes=0 by default)
vmapped_add = jax.vmap(add)

batch_x = jnp.array([1,2,3,4,5])
batch_y = jnp.array([2,3,4,5,6])

vmapped_result = vmapped_add(batch_x, batch_y)
print("Output of vmap function:", vmapped_result)
In this code, we define the add function, which adds its two arguments. We pass add to vmap to create the vmapped_add function, and then call vmapped_add with two batches of data. The result is the element-wise sum of the two batches: 3, 5, 7, 9, and 11.

The in_axes value 0 maps over the first axis of an input (for a matrix, its rows), while the value 1 maps over the second axis (for a matrix, its columns). In matrix terms, mapping a 2D input over axis 1 is equivalent to mapping over the rows of its transpose. If the function has n input parameters, in_axes can be a tuple of n values, one per parameter. In the following example, the add function has two parameters, so each in_axes value is a pair built from 0 and 1. For instance, with in_axes=(0, 1), vmap pairs each row of the first input with the corresponding column of the second, so the result is the sum of the first input and the transpose of the second.
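One way to check what a given in_axes value does is to compare the vectorized call with an explicit Python loop. The sketch below (an illustration added here, using the same add function and small 2x2 inputs) confirms that in_axes=(0, 1) pairs row i of x with column i of y. It also shows in_axes=None, which passes an argument unchanged to every call instead of mapping over it; the offset array is a hypothetical input chosen for the demonstration.

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])

# Vectorized: map over axis 0 of x (rows) and axis 1 of y (columns).
vectorized = jax.vmap(add, in_axes=(0, 1))(x, y)

# Explicit loop computing the same thing: row i of x plus column i of y.
looped = jnp.stack([add(x[i], y[:, i]) for i in range(x.shape[0])])

# in_axes=None: offset is not mapped, so it is added to every row of x.
offset = jnp.array([10, 20])
broadcast = jax.vmap(add, in_axes=(0, None))(x, offset)

print(vectorized.tolist())  # [[6, 9], [9, 12]]
print(broadcast.tolist())   # [[11, 22], [13, 24]]
```

The loop is only a reference for checking results; vmap computes the same values without a Python-level loop.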
Let’s see how the in_axes parameter affects the behavior of the vmap function in the following example:
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

# Same function, vectorized over different axes of each input
vmapped_add1 = jax.vmap(add, in_axes=(0,0))
vmapped_add2 = jax.vmap(add, in_axes=(1,0))
vmapped_add3 = jax.vmap(add, in_axes=(0,1))
vmapped_add4 = jax.vmap(add, in_axes=(1,1))

batch_x = jnp.array([[1,2],[3,4]])
batch_y = jnp.array([[5,6],[7,8]])

vmapped_result1 = vmapped_add1(batch_x, batch_y)
vmapped_result2 = vmapped_add2(batch_x, batch_y)
vmapped_result3 = vmapped_add3(batch_x, batch_y)
vmapped_result4 = vmapped_add4(batch_x, batch_y)

print("Output with in_axes = (0,0)")
print(vmapped_result1)
print("Output with in_axes = (1,0)")
print(vmapped_result2)
print("Output with in_axes = (0,1)")
print(vmapped_result3)
print("Output with in_axes = (1,1)")
print(vmapped_result4)
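In this code, we define four vmap versions of add with different values of in_axes and call each of them with the same inputs. With in_axes=(0,0), rows are added to rows. With (1,0), each column of batch_x is added to the corresponding row of batch_y. With (0,1), each row of batch_x is added to the corresponding column of batch_y. With (1,1), columns are added to columns, and the results are stacked as rows of the output.

The out_axes parameter works like in_axes, but it applies to the output: it specifies which axis of the result holds the mapped dimension.

Let’s look at the following example to understand this: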
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

# Same input mapping, but the mapped axis is placed differently in the output
vmapped_add1 = jax.vmap(add, in_axes=(1,0), out_axes=0)
vmapped_add2 = jax.vmap(add, in_axes=(1,0), out_axes=1)

# batch_x has shape (2, 3) and is mapped over its 3 columns;
# batch_y has shape (3, 2) and is mapped over its 3 rows
batch_x = jnp.array([[1,2,3],[4,5,6]])
batch_y = jnp.array([[5,6],[7,8],[9,0]])

vmapped_result1 = vmapped_add1(batch_x, batch_y)
vmapped_result2 = vmapped_add2(batch_x, batch_y)

print("Output with out_axes = 0")
print(vmapped_result1)
print("Output with out_axes = 1")
print(vmapped_result2)
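In this code, we define two vmap versions of add that use the same in_axes=(1,0) but different out_axes values, 0 and 1, and call both with the same inputs. Each call adds a column of batch_x to the matching row of batch_y, producing three results of length 2. With out_axes=0, these results are stacked along the first axis, giving an output of shape (3, 2). With out_axes=1, they are stacked along the second axis, giving an output of shape (2, 3), which is the transpose of the first.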
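When the mapped function returns several outputs, out_axes can also be a tuple with one entry per output. In this minimal sketch, add_and_scale is a hypothetical function that returns two arrays; with the same inputs and in_axes=(1,0) as the previous example, the first output is stacked along axis 0 and the second along axis 1.

```python
import jax
import jax.numpy as jnp

def add_and_scale(x, y):
    # Returns two outputs per call: the sum and the sum doubled.
    s = x + y
    return s, 2 * s

x = jnp.array([[1, 2, 3], [4, 5, 6]])
y = jnp.array([[5, 6], [7, 8], [9, 0]])

# One out_axes entry per output of add_and_scale.
sums, doubled = jax.vmap(add_and_scale, in_axes=(1, 0), out_axes=(0, 1))(x, y)

print(sums.shape)     # (3, 2): mapped axis first
print(doubled.shape)  # (2, 3): mapped axis second
```

Choosing out_axes per output can save an explicit transpose when downstream code expects the batch dimension in a particular position.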