What is the use of fit_transform from sklearn in Python?

The fit_transform method

The fit_transform() method is provided by the transformers in the sklearn.preprocessing module (for example, StandardScaler) and is used to preprocess data for model training. The fit() method calculates the required parameters from the data, and the transform() method applies those parameters to the data. In the case of StandardScaler, fit() computes the mean and the standard deviation of each feature, and transform() centers and scales the data so that each feature has a mean of 0 and a standard deviation of 1. Calling fit_transform() once is equivalent to calling fit() and then transform() on the same data.
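As a minimal sketch of how fit() and transform() divide the work, the snippet below uses a small made-up array (the values are illustrative and not part of the original example). After fit(), StandardScaler exposes the learned statistics through its mean_ and scale_ attributes.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy data (illustrative values): two features on very different scales
X = np.array([[1.0, 100.0],
              [2.0, 200.0],
              [3.0, 300.0]])

scaler = StandardScaler()
scaler.fit(X)                    # learn the per-feature mean and standard deviation
learned_mean = scaler.mean_      # per-feature means
learned_std = scaler.scale_      # per-feature standard deviations
X_scaled = scaler.transform(X)   # apply (x - mean) / std to every value

print("mean:", learned_mean)
print("std:", learned_std)
print("scaled column means:", X_scaled.mean(axis=0))
print("scaled column stds:", X_scaled.std(axis=0))
```

The scaled columns should come out with a mean of approximately 0 and a standard deviation of approximately 1, up to floating-point error.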

Preprocessing

Preprocessing is often necessary before training a model on a given dataset, for example, when training neural networks via gradient descent. If one dimension has a much larger scale than the others, gradient descent can overshoot the minimum, and the model may fail to converge.

Figure: Overshooting the minima

In comparison, gradient descent converges better when the dimensions have been scaled to have a standard deviation of 1.

Figure: Finding minima with a standard deviation of 1
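The effect of scale on convergence can be sketched on a toy problem. The example below is an assumption-laden illustration, not a neural network: it runs plain batch gradient descent on a least-squares objective, with synthetic data, a learning rate, and a step count chosen only for demonstration. The helper gradient_descent_loss is a hypothetical name introduced here.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler


def gradient_descent_loss(X, y, lr=0.5, steps=20):
    """Run plain batch gradient descent on mean squared error and return the final loss."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)  # gradient of 0.5 * mean squared error
        w -= lr * grad
    return np.mean((X @ w - y) ** 2)


# Synthetic data (assumed for illustration): the second feature is ~1000x larger
rng = np.random.default_rng(0)
X = np.column_stack([rng.normal(0, 1, 200), rng.normal(0, 1000, 200)])
y = 3.0 * X[:, 0] + 0.002 * X[:, 1]
y = y - y.mean()  # center the target so no intercept term is needed

loss_raw = gradient_descent_loss(X, y)
loss_scaled = gradient_descent_loss(StandardScaler().fit_transform(X), y)

print("final loss, raw features:   ", loss_raw)
print("final loss, scaled features:", loss_scaled)
```

With the same learning rate, the run on raw features should blow up to an enormous loss, while the run on standardized features should settle close to zero. A smaller learning rate would also stop the raw run from diverging, but it would then converge very slowly along the small-scale feature.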

The standardization formula is given by $x_{scaled} = \frac{x - \text{mean}}{\text{std. dev.}}$

When the fit_transform() method is called, the mean and the standard deviation (std. dev.) of the data are computed, and the values are scaled as per the equation above.
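To connect the formula to the method call, here is a short sketch that applies the standardization by hand with NumPy and compares it with fit_transform(). The array values are made up for illustration. Note that StandardScaler uses the population standard deviation (ddof=0), which is also NumPy's default.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative data, not taken from the iris example
X = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [4.0, 60.0]])

mean = X.mean(axis=0)         # per-feature mean
std = X.std(axis=0)           # per-feature population standard deviation (ddof=0)
X_manual = (X - mean) / std   # the standardization formula, applied column-wise

X_sklearn = StandardScaler().fit_transform(X)

print("manual:\n", X_manual)
print("sklearn:\n", X_sklearn)
print("match:", np.allclose(X_manual, X_sklearn))
```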

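Note: We call fit_transform() only on the training data. The testing dataset is preprocessed with transform() alone, which reuses the mean and standard deviation learned from the training dataset. The distribution of the testing data is treated as unknown and is assumed to be the same as that of the training dataset; fitting the scaler on it would leak information about the test set into preprocessing.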
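A small sketch of what this means in practice, using made-up one-feature data: the test value is scaled with the training mean and standard deviation, not with statistics of its own.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative data: one feature, three training samples, one test sample
X_train = np.array([[1.0], [2.0], [3.0]])
X_test = np.array([[4.0]])

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # learns mean and std from the training data
X_test_scaled = scaler.transform(X_test)        # reuses the training statistics

# The same result computed by hand from the training statistics
expected = (X_test - X_train.mean(axis=0)) / X_train.std(axis=0)

print("training mean used:", scaler.mean_)
print("scaled test value:", X_test_scaled)
print("matches manual calculation:", np.allclose(X_test_scaled, expected))
```

Because the scaler was fitted on the training data only, the scaled test set will generally not have a mean of exactly 0 or a standard deviation of exactly 1, and that is expected.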

Code example

In the following example, we'll see how to use the fit_transform() method, and how it is equivalent to using fit() and transform() together.

# import relevant libraries
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# load the iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

# create train and test splits
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

# print the original dataset
print('Original Train: \n', X_train[0:5])
print('Original Test: \n', X_test[0:5])

print("-"*50)
print('Using transform()')

# use fit() and transform() separately
std_slc = StandardScaler()
std_slc.fit(X_train)
X_train_std = std_slc.transform(X_train)
X_test_std = std_slc.transform(X_test)

print('Transformed Train: \n', X_train_std[0:5])
print('Transformed Test: \n', X_test_std[0:5])

print("-"*50)
print('Using fit_transform()')
# use fit and transform in a single function call
std_slc2 = StandardScaler()
X_train_std2 = std_slc2.fit_transform(X_train)
X_test_std2 = std_slc2.transform(X_test)

print('Transformed Train: \n', X_train_std2[0:5])
print('Transformed Test: \n', X_test_std2[0:5])

# verify that using fit_transform() equates to using fit() and transform() together
if (X_train_std2 == X_train_std).all() and (X_test_std2 == X_test_std).all():
    print('both are equivalent')

Explanation

  • Line 2–4: We import the relevant libraries.

  • Line 7–9: We load the dataset.

  • Line 12: We create train and test splits.

  • Line 15–16: We print raw train and test sets.

  • Line 22–25: We initialize StandardScaler and invoke the fit() method to calculate the mean and standard deviation parameters. Next, we use the transform() method to transform the train and test sets.

  • Line 27–28: We print the first transformed dataset.

  • Line 33–35: We initialize another instance of StandardScaler and invoke the fit_transform() method, which calculates the mean and standard deviation parameters and transforms the train dataset in a single function call. Next, we use the transform() method to transform the test dataset.

  • Line 37–38: We print the second transformed dataset.

  • Line 41–42: We verify that the two transformed datasets are identical.
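The exact == comparison works for this check because both paths perform the same computation. As a hedged alternative, floating-point arrays are often compared with a tolerance instead. The sketch below uses np.allclose and fixes random_state=0 so the split is reproducible; both choices are assumptions added here, not part of the original example.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = datasets.load_iris(return_X_y=True)
# random_state is fixed here only to make the split reproducible (an assumption)
X_train, X_test, _, _ = train_test_split(X, y, test_size=0.33, random_state=0)

two_step = StandardScaler().fit(X_train)   # fit() first ...
X_train_a = two_step.transform(X_train)    # ... then transform()

one_step = StandardScaler()
X_train_b = one_step.fit_transform(X_train)  # both in a single call

# compare with a tolerance rather than exact equality
same_train = np.allclose(X_train_a, X_train_b)
same_test = np.allclose(two_step.transform(X_test), one_step.transform(X_test))
print("train sets match:", same_train)
print("test sets match:", same_test)
```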

Copyright ©2025 Educative, Inc. All rights reserved