fit_transform method
The fit_transform() method of the transformers in the sklearn.preprocessing module, such as StandardScaler, is used to preprocess the data for model training. The fit() method calculates the required parameters, and the transform() method applies the calculated parameters to the data. In the case of StandardScaler, fit() calculates the mean and the standard deviation of each feature, and transform() centers and scales the data so that each feature has a mean of 0 and a standard deviation of 1. Calling fit_transform() once is equivalent to calling fit() and then transform() on the same data.
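As a minimal sketch of what the fit step learns, the snippet below scales a tiny made-up array (the values are illustrative, not from any real dataset) and inspects the fitted mean_ and scale_ attributes of StandardScaler.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Small made-up dataset for illustration: 4 samples, 2 features
X = np.array([[1.0, 100.0],
              [2.0, 200.0],
              [3.0, 300.0],
              [4.0, 400.0]])

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # fit and transform in one call

# Parameters learned during the fit step, one value per column
print("mean_: ", scaler.mean_)
print("scale_:", scaler.scale_)

# After scaling, each column has mean ~0 and standard deviation ~1
print("column means:", X_scaled.mean(axis=0))
print("column stds: ", X_scaled.std(axis=0))
```

The two features differ by a factor of 100 before scaling, yet end up on the same scale afterwards.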
Preprocessing is often necessary before training a model on a given dataset, for example, when training neural networks via gradient descent. If one dimension has a much larger scale than the others, a learning rate that suits the small dimensions can overshoot the minimum along the large one, and the model may fail to converge.
In comparison, gradient descent converges better when the dimensions have been scaled to have a standard deviation of 1.
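The effect of scale on gradient descent can be sketched with a plain linear model. The data, the helper name mse_after_gd, the learning rate, and the step count below are all assumptions chosen for illustration; the point is only to compare the same optimizer on raw and standardized features.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Synthetic data (made up for this sketch): x1 in [0, 1], x2 in [0, 100]
rng = np.random.default_rng(0)
X = np.column_stack([rng.uniform(0, 1, 200), rng.uniform(0, 100, 200)])
y = 3.0 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(0, 0.1, 200)

def mse_after_gd(X, y, lr=0.1, steps=20):
    """Run batch gradient descent on a linear model and return the final MSE."""
    w = np.zeros(X.shape[1])
    b = 0.0
    n = len(y)
    for _ in range(steps):
        err = X @ w + b - y               # residuals of the current model
        w -= lr * (2 / n) * (X.T @ err)   # gradient step for the weights
        b -= lr * (2 / n) * err.sum()     # gradient step for the intercept
    return np.mean((X @ w + b - y) ** 2)

initial_mse = np.mean(y ** 2)  # loss at the starting point w = 0, b = 0
raw_mse = mse_after_gd(X, y)
scaled_mse = mse_after_gd(StandardScaler().fit_transform(X), y)

print(f"initial MSE:           {initial_mse:.3g}")
print(f"raw features:          {raw_mse:.3g}")
print(f"standardized features: {scaled_mse:.3g}")
```

With these settings, the loss on the raw features ends up larger than where it started, because each step overshoots along the large-scale feature, while the loss on the standardized features decreases.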
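The standardization formula is given by

$$z = \frac{x - \mu}{\sigma}$$

When the fit_transform() method is called, the standard deviation ($\sigma$) and the mean ($\mu$) of each feature are calculated from the data, and each value $x$ is then transformed to $z$ using this formula.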
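Standardization can also be reproduced by hand: subtract each column's mean and divide by its standard deviation. The sketch below uses a small made-up array and compares the manual result with StandardScaler. Note that StandardScaler uses the population standard deviation, which matches NumPy's default of ddof=0.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up data for illustration: 3 samples, 2 features
X = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [3.0, 60.0]])

mu = X.mean(axis=0)     # per-column mean
sigma = X.std(axis=0)   # per-column population standard deviation (ddof=0)

Z_manual = (X - mu) / sigma                    # standardize by hand
Z_sklearn = StandardScaler().fit_transform(X)  # standardize with scikit-learn

print(np.allclose(Z_manual, Z_sklearn))
```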
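Note: We only apply fit_transform() to the training data. The testing dataset is preprocessed with transform() alone, which reuses the mean and standard deviation learned from the training data. This is because the distribution of the testing dataset is unknown and is assumed to be the same as that of the training dataset.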
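To make the note concrete, the following sketch (with randomly generated data standing in for a real train/test split) checks that transform() on the test set uses the statistics of the training set, not those of the test set itself.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Randomly generated stand-ins for a train/test split (illustrative only)
rng = np.random.default_rng(0)
X_train = rng.normal(loc=50.0, scale=10.0, size=(100, 1))
X_test = rng.normal(loc=50.0, scale=10.0, size=(20, 1))

scaler = StandardScaler()
X_train_std = scaler.fit_transform(X_train)  # learn mean/std from train, then scale
X_test_std = scaler.transform(X_test)        # reuse the train mean/std

# The test set is scaled with the training statistics
expected = (X_test - X_train.mean(axis=0)) / X_train.std(axis=0)
print(np.allclose(X_test_std, expected))

# The train mean is ~0 by construction; the test mean is only close to 0
print("train mean:", X_train_std.mean())
print("test mean: ", X_test_std.mean())
```

Because the test set is scaled with the training parameters, its mean after scaling is generally near 0 but not exactly 0.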
In the following example, we'll see how to use the fit_transform() method and how it is equivalent to using fit() and transform() together.
# import relevant libraries
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# load the iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

# create train and test splits
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

# print the original dataset
print('Original Train: \n', X_train[0:5])
print('Original Test: \n', X_test[0:5])

print("-"*50)
print('Using transform()')

# use fit() and transform() separately
std_slc = StandardScaler()
std_slc.fit(X_train)
X_train_std = std_slc.transform(X_train)
X_test_std = std_slc.transform(X_test)

print('Transformed Train: \n', X_train_std[0:5])
print('Transformed Test: \n', X_test_std[0:5])

print("-"*50)
print('Using fit_transform()')
# use fit and transform in a single function call
std_slc2 = StandardScaler()
X_train_std2 = std_slc2.fit_transform(X_train)
X_test_std2 = std_slc2.transform(X_test)

print('Transformed Train: \n', X_train_std2[0:5])
print('Transformed Test: \n', X_test_std2[0:5])

# verify that using fit_transform() equates to using fit() and transform() together
if (X_train_std2 == X_train_std).all() and (X_test_std2 == X_test_std).all():
    print('both are equivalent')
Lines 2–4: We import the relevant libraries.
Lines 7–9: We load the dataset.
Line 12: We create the train and test splits.
Lines 15–16: We print the raw train and test sets.
Lines 22–25: We initialize StandardScaler and invoke the fit() method to calculate the mean and standard deviation parameters. Next, we use the transform() method to transform the train and test sets.
Lines 27–28: We print the first transformed dataset.
Lines 33–35: We initialize another instance of StandardScaler and invoke the fit_transform() method, which calculates the mean and standard deviation parameters and transforms the train set in a single function call. We then transform the test set with transform().
Lines 37–38: We print the second transformed dataset.
Lines 41–42: We verify that the two transformed datasets are identical.
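The verification step relies on exact floating-point equality. A slightly more tolerant variant, sketched below under the assumption that a fixed random_state is acceptable for reproducibility, uses np.allclose to compare the two results.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = datasets.load_iris(return_X_y=True)
# random_state is fixed here only to make the sketch reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=0
)

# fit() followed by transform()
scaler_a = StandardScaler().fit(X_train)
train_a = scaler_a.transform(X_train)
test_a = scaler_a.transform(X_test)

# fit_transform() in a single call
scaler_b = StandardScaler()
train_b = scaler_b.fit_transform(X_train)
test_b = scaler_b.transform(X_test)

# compare with a floating-point tolerance instead of exact equality
same = np.allclose(train_a, train_b) and np.allclose(test_a, test_b)
print("both are equivalent:", same)
```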