What is vectorizing map (vmap) in JAX?

The vectorizing map (vmap) is a function in the JAX library that maps a function over one or more axes of its input arguments. It is a powerful tool for speeding up execution, especially when a function would otherwise be called many times in a loop on different inputs of the same shape.
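
To make the idea concrete, the following minimal sketch compares a plain Python loop with vmap. The helper name dot and the sample arrays are assumptions chosen for illustration; they are not part of the JAX API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper written for a single pair of 1-D vectors.
def dot(a, b):
    return jnp.sum(a * b)

# Two batches of three vectors each (illustrative values).
batch_a = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_b = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# Without vmap: call the function once per pair in a Python loop.
looped = jnp.stack([dot(a, b) for a, b in zip(batch_a, batch_b)])

# With vmap: one call handles the whole batch along axis 0.
batched = jax.vmap(dot)(batch_a, batch_b)

print(looped)
print(batched)
```

Both approaches should produce the same values; the vmap version avoids the Python-level loop.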

Syntax

The syntax of the vmap function is as follows:

jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)

Parameters

The vmap function has the following three main parameters:

  1. fun is the function to be mapped. It is the only required parameter.
  2. in_axes specifies which axis of each input argument should be mapped over. It defaults to 0.
  3. out_axes indicates where the mapped axis should appear in the function’s output. It defaults to 0.
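
As a rough illustration of in_axes and out_axes, the sketch below maps over the first argument only and then moves the mapped axis in the output. The function name scale_and_shift and the input values are assumptions made for this example. Passing None in in_axes tells vmap not to map over that argument, so it is shared by every call.

```python
import jax
import jax.numpy as jnp

# Hypothetical function written for one 1-D vector x and one shared vector w.
def scale_and_shift(x, w):
    return x * w + 1.0

xs = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # batch of 2 vectors, shape (2, 3)
w = jnp.array([10.0, 20.0, 30.0])                   # shared vector, shape (3,)

# Map over axis 0 of xs; None means w is not mapped.
rows_first = jax.vmap(scale_and_shift, in_axes=(0, None))(xs, w)

# Same mapping, but place the mapped axis at position 1 of the output.
rows_last = jax.vmap(scale_and_shift, in_axes=(0, None), out_axes=1)(xs, w)

print(rows_first.shape)  # (2, 3)
print(rows_last.shape)   # (3, 2)
```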

The other parameters are optional and are listed below:

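  1. axis_name is a hashable name for the mapped axis, which lets collective operations inside the function refer to it.
  2. axis_size is an integer giving the size of the axis to be mapped. It is needed only when the size cannot be inferred from the arguments.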
  3. spmd_axis_name is an optional name for the axis used for parallel execution in single program multiple data (SPMD) parallelism.
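
The axis_name parameter is easiest to understand through a collective operation. The sketch below, which assumes the axis name "batch" and a hypothetical function normalize, uses jax.lax.psum to sum a value across the mapped axis from inside the mapped function.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-element function: divide each element by the sum
# of all elements along the mapped axis named "batch".
def normalize(x):
    total = jax.lax.psum(x, axis_name="batch")
    return x / total

values = jnp.array([1.0, 2.0, 3.0, 4.0])

# axis_name gives the mapped axis a name that psum can refer to.
normalized = jax.vmap(normalize, axis_name="batch")(values)

print(normalized)
```

Without axis_name, the function would see only one element at a time and would have no way to refer to the rest of the batch.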

Return object

The vmap function returns a batched (vectorized) version of fun. The returned function applies the original function along the mapped axes of its inputs and stacks the results along the axis given by out_axes.
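
Because the returned object is an ordinary function, it can be combined with other JAX transformations. The following sketch computes per-example gradients; the loss function, its parameters, and the data are assumptions invented for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical squared-error loss for a single example (x, y) and weights w.
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 examples
ys = jnp.array([0.0, 1.0, 2.0])

# grad differentiates with respect to w; vmap then maps over the examples
# while sharing w (in_axes=None for the first argument); jit compiles the result.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

grads = per_example_grads(w, xs, ys)
print(grads.shape)  # (3, 2)
```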

How to use vmap

To use the vmap function, we first define the function we want to vectorize. Then, we call vmap with that function as the first argument and, optionally, the axes we want to map over (in_axes) as the second argument.

Example 1

Consider the following example:

import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

vmapped_add = jax.vmap(add)

batch_x = jnp.array([1,2,3,4,5])
batch_y = jnp.array([2,3,4,5,6])

vmapped_result = vmapped_add(batch_x, batch_y)
print("Output of vmap function:", vmapped_result)

  • Lines 4–6: We define the add function, which performs a simple addition of two variables.
  • Line 7: To perform this operation on two batches of numbers, we pass the add function to vmap and store the result as vmapped_add.
  • Line 12: We call the vmapped_add function with two batches of data. The result is the element-wise sum of the two batches.

Example 2

A value of 0 in in_axes maps over axis 0 (the rows) of the corresponding argument, while a value of 1 maps over axis 1 (the columns). For a two-dimensional input, mapping over axis 1 has the same effect as transposing the matrix and then mapping over its rows. If the function has n parameters, in_axes can be a tuple of n values, one for each parameter. In the following example, the add function has two parameters, so in_axes holds two values, each either 0 or 1. With in_axes=(0,1), vmap pairs each row of the first argument with the corresponding column of the second, which for the add function is equivalent to adding the first matrix to the transpose of the second.

Let’s see how the in_axes parameter affects the behavior of the vmap function in the following example:

import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

vmapped_add1 = jax.vmap(add, in_axes=(0,0))
vmapped_add2 = jax.vmap(add, in_axes=(1,0))
vmapped_add3 = jax.vmap(add, in_axes=(0,1))
vmapped_add4 = jax.vmap(add, in_axes=(1,1))

batch_x = jnp.array([[1,2],
                     [3,4]])

batch_y = jnp.array([[5,6],
                     [7,8]])

vmapped_result1 = vmapped_add1(batch_x, batch_y)
vmapped_result2 = vmapped_add2(batch_x, batch_y)
vmapped_result3 = vmapped_add3(batch_x, batch_y)
vmapped_result4 = vmapped_add4(batch_x, batch_y)

print("Output with in_axes = (0,0)")
print(vmapped_result1)
print("Output with in_axes = (1,0)")
print(vmapped_result2)
print("Output with in_axes = (0,1)")
print(vmapped_result3)
print("Output with in_axes = (1,1)")
print(vmapped_result4)

  • Lines 7–10: We create four vmapped versions of add, each with a different value of in_axes.
  • Lines 18–21: We call all four vmapped functions with the same inputs to compare their outputs.
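
One way to read these four results is to compare them with explicit transposes. The sketch below reuses the add function and the two 2×2 inputs from the example and checks each result against a plain jax.numpy expression. The short variable names are assumptions for this sketch, and the equivalences hold for this element-wise add function on two-dimensional inputs; they are not a general rule for every function.

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])

# Rows of x paired with rows of y: the ordinary element-wise sum.
r00 = jax.vmap(add, in_axes=(0, 0))(x, y)
# Columns of x paired with rows of y: same as transposing x first.
r10 = jax.vmap(add, in_axes=(1, 0))(x, y)
# Rows of x paired with columns of y: same as transposing y first.
r01 = jax.vmap(add, in_axes=(0, 1))(x, y)
# Columns paired with columns: the element-wise sum, transposed.
r11 = jax.vmap(add, in_axes=(1, 1))(x, y)

print(jnp.array_equal(r00, x + y))
print(jnp.array_equal(r10, x.T + y))
print(jnp.array_equal(r01, x + y.T))
print(jnp.array_equal(r11, (x + y).T))
```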

Example 3

The out_axes parameter works like in_axes, but it applies to the output: it specifies the position at which the mapped axis appears in the result.

Let’s see the following example to understand that:

import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

vmapped_add1 = jax.vmap(add, in_axes=(1,0), out_axes=0)
vmapped_add2 = jax.vmap(add, in_axes=(1,0), out_axes=1)


batch_x = jnp.array([[1,2,3],
                     [4,5,6]])
batch_y = jnp.array([[5,6],
                     [7,8],
                     [9,0]])

vmapped_result1 = vmapped_add1(batch_x, batch_y)
vmapped_result2 = vmapped_add2(batch_x, batch_y)

print("Output with out_axes = 0")
print(vmapped_result1)
print("Output with out_axes = 1")
print(vmapped_result2)

  • Lines 7–8: We create two vmapped functions, one with out_axes set to 0 and one with out_axes set to 1.
  • Lines 11–15: We create two batches of input values with different shapes.
  • Lines 17–18: We call both vmapped functions with the same inputs to see the difference in the output. The mapped axis appears first in one result and second in the other.
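
The effect of out_axes is easiest to see in the output shapes. The following sketch reuses the add function and the inputs from the example above; the short variable names are assumptions for this sketch. It checks that the two results hold the same values and differ only in where the mapped axis is placed.

```python
import jax
import jax.numpy as jnp

def add(x, y):
    return x + y

x = jnp.array([[1, 2, 3], [4, 5, 6]])    # shape (2, 3)
y = jnp.array([[5, 6], [7, 8], [9, 0]])  # shape (3, 2)

# Both versions pair column i of x with row i of y (3 pairs of length-2 vectors).
out0 = jax.vmap(add, in_axes=(1, 0), out_axes=0)(x, y)
out1 = jax.vmap(add, in_axes=(1, 0), out_axes=1)(x, y)

print(out0.shape)  # (3, 2): the mapped axis comes first
print(out1.shape)  # (2, 3): the mapped axis comes second
print(jnp.array_equal(out1, out0.T))
```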

Copyright ©2025 Educative, Inc. All rights reserved