The expand_dims() function in NumPy expands the shape of the input array passed to it by inserting a new axis of length 1. The new axis appears at the position given by axis in the shape of the resulting expanded array.
Every dimension of a NumPy array has a corresponding axis. For example, a 2-D array has two axes: axis 0 runs vertically downward across the rows, and axis 1 runs horizontally across the columns.
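To make the two axes concrete, here is a minimal sketch that sums a small 2-D array along each axis. The array m and the variable names are illustrative and are not part of the original example.

```python
import numpy as np

# A 2-D array with 2 rows and 3 columns (illustrative data)
m = np.array([[1, 2, 3],
              [4, 5, 6]])

# Summing along axis 0 moves down the rows: one result per column
col_sums = m.sum(axis=0)

# Summing along axis 1 moves across the columns: one result per row
row_sums = m.sum(axis=1)

print(col_sums)  # [5 7 9]
print(row_sums)  # [ 6 15]
```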
numpy.expand_dims(a, axis)
The expand_dims() function takes the following parameter values:
a: This is the input array.
axis: This is the position in the resulting array where the new axis is placed.
The expand_dims() function returns a view of the input array with an increased number of dimensions.
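Because the result is a view, it shares its data with the input array rather than copying it. The short sketch below illustrates this with np.shares_memory(). The value 99 is an arbitrary choice for demonstration.

```python
import numpy as np

a = np.array([1, 2, 3, 4])

# Insert a new axis at position 0, giving shape (1, 4)
b = np.expand_dims(a, axis=0)

# The expanded array shares its data buffer with the input
print(np.shares_memory(a, b))  # True

# Writing through the view also changes the original array
b[0, 0] = 99
print(a[0])  # 99
```

If an independent array is needed, calling b.copy() on the result avoids this coupling.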
import numpy as np

# creating an input array
a = np.array([1, 2, 3, 4])

# getting the dimension of a
print(a.shape)

# expanding the axis of a
b = np.expand_dims(a, axis=1)

# getting the dimension of the new array
print(b.shape)
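We import the numpy module.
We create an input array a using the array() function.
We print the dimensions of a using the shape attribute.
We expand the dimensions of a using the expand_dims() function. Then, we assign the result to a variable b.
We print the dimensions of the newly expanded array b using the shape attribute.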
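The effect of the axis argument is easiest to see by comparing several positions on the same input. The following sketch reuses the array a from the example. The names row, col, and last are illustrative only.

```python
import numpy as np

a = np.array([1, 2, 3, 4])

# New axis at the front: a single row of 4 elements
row = np.expand_dims(a, axis=0)

# New axis at position 1: 4 rows of 1 element each
col = np.expand_dims(a, axis=1)

# Negative values count from the end of the resulting shape
last = np.expand_dims(a, axis=-1)

print(row.shape)   # (1, 4)
print(col.shape)   # (4, 1)
print(last.shape)  # (4, 1)

# Indexing with np.newaxis produces the same result as axis=1
print(np.array_equal(col, a[:, np.newaxis]))  # True
```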