What is the numpy.expand_dims() function in NumPy?

Overview

The expand_dims() function in NumPy expands the shape of an input array by inserting a new axis of length 1. The new axis appears at the position given by the axis argument in the shape of the resulting array.

An axis in NumPy is a dimension of an array, and every array has one axis per dimension. A 1-D array has a single axis (axis 0). A 2-D array has two axes: axis 0 runs vertically downward across the rows, and axis 1 runs horizontally across the columns.
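To make the two axes of a 2-D array concrete, here is a small sketch that sums along each axis. The array m and the variable names are illustrative only and are not part of the example used later in this answer.

```python
import numpy as np

# A 2-D array with 2 rows and 3 columns, so its shape is (2, 3)
m = np.array([[1, 2, 3],
              [4, 5, 6]])

# Summing along axis 0 moves down the rows: one result per column
col_sums = m.sum(axis=0)   # values 5, 7, 9

# Summing along axis 1 moves across the columns: one result per row
row_sums = m.sum(axis=1)   # values 6, 15

print(m.shape)
print(col_sums.shape, row_sums.shape)
```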

Syntax

numpy.expand_dims(a, axis)
Syntax for the "expand_dims()" function

Parameter values

The expand_dims() function takes the following values:

  • a: This is the input array.
  • axis: This is the position in the shape of the resulting array where the new axis is placed. It can be an integer (negative values count from the end) or a tuple of integers.
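The effect of the axis parameter is easiest to see by comparing the resulting shapes. The following is a minimal sketch; the variable names row, col, last, and both are illustrative, and the tuple form assumes NumPy 1.18 or later.

```python
import numpy as np

a = np.array([1, 2, 3, 4])             # shape (4,)

row = np.expand_dims(a, axis=0)        # new axis first: shape (1, 4)
col = np.expand_dims(a, axis=1)        # new axis second: shape (4, 1)
last = np.expand_dims(a, axis=-1)      # negative index counts from the end: shape (4, 1)

# A tuple inserts several axes at once (assumes NumPy 1.18 or later)
both = np.expand_dims(a, axis=(0, 2))  # shape (1, 4, 1)

print(row.shape, col.shape, last.shape, both.shape)
```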

Return value

The expand_dims() function returns a view of the input array with an increased number of dimensions.
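Because the result is a view, it shares its data with the input array rather than copying it. The sketch below illustrates this with np.shares_memory(); the value 99 is arbitrary.

```python
import numpy as np

a = np.array([1, 2, 3, 4])
b = np.expand_dims(a, axis=1)    # shape (4, 1), same underlying data as a

# Check that both arrays refer to the same memory
shared = np.shares_memory(a, b)
print(shared)

# Writing through the view changes the original array as well
b[0, 0] = 99
print(a[0])
```

If an independent array is needed, calling .copy() on the result avoids modifying the input.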

Code

import numpy as np
# creating a 1-D input array
a = np.array([1, 2, 3, 4])

# printing the shape of a
print(a.shape)  # (4,)

# inserting a new axis at position 1
b = np.expand_dims(a, axis=1)

# printing the shape of the new array
print(b.shape)  # (4, 1)

Explanation

  • Line 1: We import the numpy module.
  • Line 3: We create an input array a using the array() function.
  • Line 6: We print the shape of a using the shape attribute.
  • Line 9: We expand the input array a using the expand_dims() function. Then, we assign the result to a variable b.
  • Line 12: We obtain and print the shape attribute of the newly expanded array b.
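For comparison, the same (4, 1) result can be produced with indexing or reshaping. This is a sketch of equivalent alternatives, not part of the example above; the names b1, b2, and b3 are illustrative.

```python
import numpy as np

a = np.array([1, 2, 3, 4])

b1 = np.expand_dims(a, axis=1)   # explicit function call
b2 = a[:, np.newaxis]            # indexing with np.newaxis
b3 = a.reshape(-1, 1)            # reshaping to one column

# All three produce an array of shape (4, 1) with the same values
print(b1.shape, b2.shape, b3.shape)
print(np.array_equal(b1, b2) and np.array_equal(b1, b3))
```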
