What is reshape() in TensorFlow?

TensorFlow is an open-source framework, used mainly from Python, for building and training deep learning models. Its core data structure is the tensor, a multi-dimensional array. Tensors flow through a series of operations that together make up a model.

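The built-in reshape() function in TensorFlow changes the shape of a tensor without changing its elements or their order. This is useful when a tensor has to match the shape that an operation or layer expects. The new shape must hold exactly as many elements as the original tensor.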
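A minimal sketch of that guarantee, using an arbitrary 12-element tensor built with `tf.range`: reshaping only regroups the elements into different dimensions. The element count stays the same, the elements are laid out in row-major order, and reshaping back recovers the original sequence.

```python
import tensorflow as tf

# 12 elements: 0, 1, ..., 11 (shape [12])
t = tf.range(12)

# Regroup the same 12 elements into different shapes
m = tf.reshape(t, [3, 4])     # 3 rows of 4
c = tf.reshape(t, [2, 2, 3])  # 2 blocks of 2 rows of 3

# The element count never changes
print(tf.size(t).numpy(), tf.size(m).numpy(), tf.size(c).numpy())  # 12 12 12

# Elements fill the new shape row by row
print(m.numpy()[1])  # [4 5 6 7]

# Reshaping back to [12] recovers the original sequence
back = tf.reshape(c, [12])
print(tf.reduce_all(tf.equal(back, t)).numpy())  # True
```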

Syntax

tensorflow.reshape(tensor, shape, name=None)

Parameters

  1. tensor: the tensor to be reshaped
  2. shape: the shape of the output tensor, given as a list or 1-D tensor of integers; at most one entry may be -1
  3. name: a name for the operation (optional)

The function returns a new tensor with the same elements and data type as tensor, arranged in the requested shape.
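A short sketch of these parameters and the return value. The tensor values and the name string "to_matrix" are arbitrary choices for illustration. The returned tensor keeps the input's dtype, and if the requested shape does not hold exactly as many elements as the input, the call fails; in eager mode (the TensorFlow 2 default) this surfaces as `tf.errors.InvalidArgumentError`.

```python
import tensorflow as tf

x = tf.constant([1.5, 2.5, 3.5, 4.5])  # float32, shape [4]

# Reshape into a 2 x 2 matrix; name is optional
y = tf.reshape(x, [2, 2], name="to_matrix")
print(y.shape)              # (2, 2)
print(y.dtype == x.dtype)   # True: the data type is preserved

# 4 elements cannot fill a 3 x 2 shape (6 slots), so this call fails
try:
    tf.reshape(x, [3, 2])
    failed = False
except tf.errors.InvalidArgumentError:
    failed = True
print("Incompatible shape rejected:", failed)
```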

Code

How to use the reshape() function

import tensorflow as tf
a = tf.constant([3, 4, 8, 1, 7, 2])             # define a tensor of shape [6]
b = tf.constant([[7, 8, 9, 10], [4, 5, 6, 7]])  # define a tensor of shape [2, 4]
print("Shape of tensor a:", a.get_shape())
print("Shape of tensor b:", b.get_shape())
# reshape both tensors and store the results in new tensors
aNew = tf.reshape(a, [2, 3])
bNew = tf.reshape(b, [4, 2])
# print the shapes of the output tensors
print("Tensor a after reshaping:", aNew.get_shape())
print("Tensor b after reshaping:", bNew.get_shape())
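The example above prints only shapes. A minimal follow-up with the same two tensors prints the reshaped values, showing that the elements keep their original order and fill the new shape row by row. It assumes TensorFlow 2 eager execution, where `.numpy()` converts a tensor to a NumPy array.

```python
import tensorflow as tf

a = tf.constant([3, 4, 8, 1, 7, 2])             # shape [6]
b = tf.constant([[7, 8, 9, 10], [4, 5, 6, 7]])  # shape [2, 4]

aNew = tf.reshape(a, [2, 3])
bNew = tf.reshape(b, [4, 2])

# Elements are taken in order and refilled row by row
print(aNew.numpy())  # rows: [3 4 8] and [1 7 2]
print(bNew.numpy())  # rows: [7 8], [9 10], [4 5], [6 7]
```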

If [-1] is passed as the shape argument, the function returns a one-dimensional tensor that contains all the elements of the input in their original order. This is called flattening a tensor.

Flattening a tensor

import tensorflow as tf
a = tf.constant([[3, 4], [7, 8], [1, 2], [9, 10], [11, 12]])  # define a tensor of shape [5, 2]
b = tf.constant([[7, 8, 9, 10], [4, 5, 6, 7]])                # define a tensor of shape [2, 4]
print("Shape of tensor a:", a.get_shape())
print("Shape of tensor b:", b.get_shape())
# flatten both tensors
aNew = tf.reshape(a, [-1])
bNew = tf.reshape(b, [-1])
# print the shapes of the output tensors
print("Tensor a after reshaping:", aNew.get_shape())
print("Tensor b after reshaping:", bNew.get_shape())
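-1 is not limited to flattening. In a multi-dimensional shape, at most one entry may be -1, and TensorFlow infers that dimension from the total number of elements. A sketch using the same 10-element tensor a as the flattening example:

```python
import tensorflow as tf

a = tf.constant([[3, 4], [7, 8], [1, 2], [9, 10], [11, 12]])  # shape [5, 2], 10 elements

# Fix the first dimension and let -1 infer the second: 10 / 2 = 5
rows = tf.reshape(a, [2, -1])
print(rows.shape)  # (2, 5)

# Fix the first dimension at 10: 10 / 10 = 1, giving a column vector
col = tf.reshape(a, [10, -1])
print(col.shape)   # (10, 1)
```

The inferred dimension must divide the element count evenly; asking for [3, -1] with 10 elements would fail.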
