Curve fitting is a statistical technique for finding a mathematical model that best fits a set of data points. The goal is to approximate the relationship between variables by adjusting the parameters of a chosen function or curve. The need for curve fitting arises from the inherent variability and noise in real-world data: the data we collect rarely follows a known mathematical equation or model exactly. Curve fitting lets us identify a mathematical representation that captures the underlying trends or patterns in the data. This makes it easier to predict values, analyze relationships between variables, and gain insight into the system’s behavior.
The curve that best fits a dataset may be a straight line, a polynomial, an exponential curve, or an even more complex mathematical function. The curve_fit() function in Python’s SciPy library (scipy.optimize.curve_fit) provides a generic way to fit a function of our choice to data using the method of least squares.
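Least squares means choosing the parameters that minimize the sum of squared residuals, that is, the sum of (y_i − f(x_i))² over all data points. The sketch below checks this on a small synthetic dataset. The data, the seed, and the helper name sum_sq_residuals are illustrative assumptions. Moving either fitted parameter slightly away from the optimum increases that sum.

```python
import numpy as np
from scipy.optimize import curve_fit

def line(x, m, c):
    """Straight-line model f(x) = m*x + c."""
    return m * x + c

def sum_sq_residuals(params, x, y):
    """Hypothetical helper: the least-squares objective sum((y - f(x))**2)."""
    return np.sum((y - line(x, *params)) ** 2)

# Synthetic data (an assumption): slope 2, intercept -1, plus seeded Gaussian noise
rng = np.random.default_rng(0)
x = np.linspace(-5, 5, 30)
y = line(x, 2.0, -1.0) + rng.normal(scale=0.5, size=x.size)

par, cov = curve_fit(line, x, y)
best = sum_sq_residuals(par, x, y)

# Nudging either parameter away from the fitted value makes the objective larger
for delta in ([0.05, 0.0], [0.0, 0.05], [-0.05, 0.0], [0.0, -0.05]):
    assert sum_sq_residuals(par + np.array(delta), x, y) > best

print("fitted m, c:", par)
print("sum of squared residuals at the optimum:", best)
```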
The method of fitting a curve to data using SciPy can be broken down into the following steps:
1. Estimate which type of curve fits the data best and write a function for it.
2. Ask the curve_fit() function to find the parameter values that best fit this curve to the data.
3. Plot the fitted curve at a higher resolution than the original data.
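As a compact sketch of these three steps, here they are applied to an exponential curve y = a·exp(b·x). The model name exp_model and the synthetic data (a = 2, b = 0.5, seeded noise) are assumptions for illustration only.

```python
import numpy as np
from scipy.optimize import curve_fit

# Step 1: choose a model shape, here an exponential curve y = a * exp(b * x)
def exp_model(x, a, b):
    return a * np.exp(b * x)

# Synthetic data (an assumption): a = 2, b = 0.5, plus seeded Gaussian noise
rng = np.random.default_rng(1)
x = np.linspace(0, 4, 40)
y = exp_model(x, 2.0, 0.5) + rng.normal(scale=0.1, size=x.size)

# Step 2: let curve_fit() find the parameters that best fit the data
par, cov = curve_fit(exp_model, x, y)
a_fit, b_fit = par

# Step 3: evaluate the fitted curve on a much denser grid than the data
x_fit = np.linspace(x.min(), x.max(), 400)
y_fit = exp_model(x_fit, *par)

print(f"a = {a_fit:.3f}, b = {b_fit:.3f}")
```

Passing x_fit and y_fit to plt.plot(), as the examples below do, draws the fitted curve.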
This Answer will look into some curve-fitting examples using this function. Let’s start with linear curve fitting and explore the steps in detail.
The most straightforward case of curve fitting is when the data follows a linear trend. This kind of trend is described by the equation of a straight line:

$$y = mx + c$$

where m is the slope of the line and c is its y-intercept.
The sample data below can be fitted with a straight line. The data.csv file contains the data we want to fit. It has two rows: the first row contains the x-values, and the second row contains the y-values.
```
-5.000000000000000000e+00 -4.599999999999999645e+00 -4.199999999999999289e+00 -3.799999999999998934e+00 -3.399999999999998579e+00 -2.999999999999998224e+00 -2.599999999999997868e+00 -2.199999999999997513e+00 -1.799999999999997158e+00 -1.399999999999996803e+00 -9.999999999999964473e-01 -5.999999999999960920e-01 -1.999999999999957367e-01 2.000000000000046185e-01 6.000000000000049738e-01 1.000000000000005329e+00 1.400000000000005684e+00 1.800000000000006040e+00 2.200000000000006395e+00 2.600000000000006750e+00 3.000000000000007105e+00 3.400000000000007461e+00 3.800000000000007816e+00 4.200000000000008171e+00 4.600000000000008527e+00
-1.200393295214232303e+01 -1.233913950615116484e+01 -9.661010189588212782e+00 -7.307300632638694005e+00 -7.184871325883614546e+00 -5.802418833159632250e+00 -6.552636647945480064e+00 -5.165594306221565013e+00 -5.381493519580181406e+00 -4.023389714649279192e+00 -2.278417567863197934e+00 -1.903660355386104452e+00 -1.404404506946818731e+00 -5.327857134101845471e-01 2.146006765500340641e-01 7.756569205825256663e-01 2.166547047761087530e+00 2.756784610767706756e+00 4.196571876840264004e+00 4.488804319821260158e+00 4.576823340806751794e+00 5.496632751611786105e+00 6.371882782636471454e+00 5.553759977045843677e+00 7.073528897432594498e+00
```
To inspect the data, we read data.csv with np.loadtxt(), store the first row in x and the second row in y, and draw the points as a scatter plot.
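np.loadtxt() turns each line of the file into one row of a 2D array, so data[0] holds the x-values and data[1] the y-values. The minimal check below uses only the first three columns of data.csv. It reads them through io.StringIO, an in-memory stand-in for the real file used here purely for illustration.

```python
import io
import numpy as np

# Stand-in for data.csv: its first three columns, in the same layout
# (row 1 = x-values, row 2 = y-values, separated by spaces)
sample = io.StringIO(
    "-5.000000000000000000e+00 -4.599999999999999645e+00 -4.199999999999999289e+00\n"
    "-1.200393295214232303e+01 -1.233913950615116484e+01 -9.661010189588212782e+00\n"
)

data = np.loadtxt(sample)  # one array row per line of the file
print(data.shape)          # (2, 3): 2 rows, 3 columns

# Unpacking a 2-row array yields its rows, the same as x = data[0]; y = data[1]
x, y = data
```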
The points suggest a linear trend. Next, we define a generic function for a straight line; let’s call it model_function. This function implements the equation of a straight line.
The first argument of model_function must be the values of the independent variable. The remaining arguments are the parameters that curve_fit() adjusts to achieve the best fit (m and c in this case).
```python
import numpy as np
import matplotlib.pyplot as plt

def model_function(x, m, c):
    return m*x + c

data = np.loadtxt("data.csv")

x = data[0]
y = data[1]

plt.scatter(x, y)
plt.savefig("/usercode/output/plot.png")
```
In the code above, on lines 4–5, we define a generic function that can fit the shape of the trend followed by our data, a straight line in this case.
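curve_fit() calls the model with the whole array of x-values at once, so the model has to operate element-wise on NumPy arrays. NumPy arithmetic and functions such as np.exp() do this. Scalar functions from Python’s math module do not. The two exponential variants below are hypothetical, added only to show the contrast.

```python
import math
import numpy as np

def model_function(x, m, c):
    # NumPy arithmetic is element-wise, so this works for an array of x-values
    return m*x + c

def exp_model_math(x, a, b):
    # math.exp() accepts only scalars, so this version fails on arrays
    return a * math.exp(b * x)

def exp_model_numpy(x, a, b):
    # np.exp() is element-wise, which is what curve_fit() needs
    return a * np.exp(b * x)

x = np.array([-1.0, 0.0, 1.0, 2.0])
y_line = model_function(x, 2.0, 1.0)   # one y-value per x-value
y_exp = exp_model_numpy(x, 1.0, 0.5)

try:
    exp_model_math(x, 1.0, 0.5)
except TypeError as err:
    print("math.exp() rejects arrays:", err)
```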
Now that we have a function that can describe our data, we pass it to the curve_fit() function along with the data points. curve_fit() uses the least-squares method to find the parameter values that make the function fit the data best. It returns two values: an array of optimal values for the function parameters, and a 2D array containing the estimated covariance of those values. To plot the best-fitting curve, we only need the optimal parameters. Run the code below to see the optimal values for our data.
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def model_function(x, m, c):
    return m*x + c

data = np.loadtxt("data.csv")

x = data[0]
y = data[1]

par, cov = curve_fit(model_function, x, y)

print("Optimal parameters chosen: m = {:.2f}, c = {:.2f}".format(*par))
```
Line 13: Calling the curve_fit() function to find the optimal curve-fitting parameters. The parameters are stored as an array in the variable par, and the covariance matrix in cov.
Line 15: Printing the optimal parameters found by the curve_fit() function.
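The covariance matrix returned by curve_fit() is not needed for plotting. However, its diagonal holds the estimated variances of the fitted parameters, so square roots of those entries give one-standard-deviation uncertainties, as the SciPy documentation for curve_fit() suggests. The minimal sketch below uses synthetic straight-line data; the data and the seed are assumptions, not data.csv.

```python
import numpy as np
from scipy.optimize import curve_fit

def model_function(x, m, c):
    return m*x + c

# Synthetic straight-line data (an assumption): slope 2, intercept -1, seeded noise
rng = np.random.default_rng(2)
x = np.linspace(-5, 5, 25)
y = model_function(x, 2.0, -1.0) + rng.normal(scale=0.8, size=x.size)

par, cov = curve_fit(model_function, x, y)

# Square roots of the covariance diagonal = one-standard-deviation errors
perr = np.sqrt(np.diag(cov))
for name, value, err in zip(("m", "c"), par, perr):
    print(f"{name} = {value:.2f} ± {err:.2f}")
```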
The next and final step is to make a high-resolution plot using the optimal parameters and the generic function definition. The code below shows how to do it.
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def model_function(x, m, c):
    return m*x + c

data = np.loadtxt("data.csv")

x = data[0]
y = data[1]

par, cov = curve_fit(model_function, x, y)

x_fit = np.arange(min(x), max(x), 0.1)
y_fit = model_function(x_fit, *par)

plt.scatter(x, y, label="Data points")
plt.plot(x_fit, y_fit, label="Best fit")
plt.legend()
plt.savefig("/usercode/output/plot.png")
```
Line 15: Defining a range of values of the independent variable from the minimum to the maximum value of the dataset in steps of 0.1.
Line 16: Generating the corresponding values of the dependent variable based on the model function with the best-fit parameters.
Lines 18–19: Plotting data points as a scatter plot and the best-fit points as a line plot.
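One detail of line 15: np.arange() excludes its end value, so the fitted line stops just short of the largest x-value. If the curve should span the full data range, np.linspace() with a chosen number of points is an alternative, because it includes both ends by default. The small comparison below uses a three-point range that is purely illustrative.

```python
import numpy as np

x = np.array([0.0, 0.5, 1.0])  # illustrative data range, not from data.csv

# np.arange() stops before the end value: the last value here is 0.75
x_ar = np.arange(x.min(), x.max(), 0.25)

# np.linspace() takes a number of points and includes both ends: last value 1.0
x_lin = np.linspace(x.min(), x.max(), 5)

print(x_ar)
print(x_lin)
```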
The same steps also work for datasets with more complex mathematical relationships. In the code below, we fit one dataset that follows a quadratic trend and another that follows a sinusoidal trend.
For data following a quadratic trend, we only need to modify the model function slightly. Notice that we now have three parameters instead of the two in the linear case, but curve_fit() handles this without any other change to the code.
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def model_func(x, a, b, c):
    return a*(x+b)**2 + c

data = np.loadtxt("quad_data.csv")

x = data[0]
y = data[1]

par, cov = curve_fit(model_func, x, y)

x_fit = np.arange(-7, 4.8, 0.05)
y_fit = model_func(x_fit, *par)

plt.scatter(x, y, label="Data points")
plt.plot(x_fit, y_fit, label="Best fit")
plt.legend()
plt.show()
```
Lines 5–6: Defining a model function for the quadratic data.
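The model a*(x+b)**2 + c is the vertex form of a parabola. Expanding it gives a*x**2 + 2*a*b*x + (a*b**2 + c), so for a ≠ 0 it describes the same family of curves as an ordinary quadratic. As a cross-check, the sketch below fits synthetic quadratic data in two ways: with curve_fit(), and with np.polyfit(), which fits standard polynomial coefficients directly. It then compares the expanded coefficients. The data and the seed are assumptions; quad_data.csv is not used.

```python
import numpy as np
from scipy.optimize import curve_fit

def model_func(x, a, b, c):
    return a*(x+b)**2 + c

# Synthetic quadratic data (an assumption): a = 0.5, b = 1.5, c = -2, seeded noise
rng = np.random.default_rng(3)
x = np.linspace(-7, 4, 50)
y = model_func(x, 0.5, 1.5, -2.0) + rng.normal(scale=0.5, size=x.size)

a, b, c = curve_fit(model_func, x, y)[0]

# Expand a*(x+b)**2 + c into standard-form coefficients for [x**2, x, 1]
expanded = np.array([a, 2*a*b, a*b**2 + c])

# np.polyfit returns coefficients ordered from the highest power down
poly = np.polyfit(x, y, 2)

print("from curve_fit:", expanded)
print("from polyfit:  ", poly)
```

Both approaches minimize the same sum of squared residuals, so their coefficients should agree up to the optimizer's tolerance.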
This time, the data follows a sinusoidal trend, and again, we just need to modify the model function a bit.
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def model_func(x, a, b, c):
    return a * np.sin(x+b) + c

data = np.loadtxt("sin_data.csv")

x = data[0]
y = data[1]

par, cov = curve_fit(model_func, x, y)

x_fit = np.arange(0, 10, 0.05)
y_fit = model_func(x_fit, *par)

plt.scatter(x, y, label="Data points")
plt.plot(x_fit, y_fit, label="Best fit")
plt.legend()
plt.show()
```
Lines 5–6: Defining a model function for the sinusoidal data.
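curve_fit() does not report how well the fitted curve follows the data. One common summary is the coefficient of determination, R² = 1 − SS_res/SS_tot, where SS_res is the sum of squared residuals and SS_tot is the sum of squared deviations of y from its mean. The sketch below computes it for the sinusoidal model above. The synthetic data (a = 3, b = 0.5, c = 1, seeded noise) is an assumption and is not taken from sin_data.csv.

```python
import numpy as np
from scipy.optimize import curve_fit

def model_func(x, a, b, c):
    return a * np.sin(x+b) + c

# Synthetic sinusoidal data (an assumption): a = 3, b = 0.5, c = 1, seeded noise
rng = np.random.default_rng(4)
x = np.linspace(0, 10, 60)
y_true = model_func(x, 3.0, 0.5, 1.0)
y = y_true + rng.normal(scale=0.3, size=x.size)

par, cov = curve_fit(model_func, x, y)

# R^2 = 1 - SS_res / SS_tot
residuals = y - model_func(x, *par)
ss_res = np.sum(residuals**2)
ss_tot = np.sum((y - np.mean(y))**2)
r_squared = 1 - ss_res / ss_tot
print(f"R^2 = {r_squared:.3f}")
```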
We have covered the basics of curve fitting in this answer, which should be enough for most everyday coding tasks. However, complex datasets may need additional control, such as providing an initial estimate of the fitting parameters. For such options, refer to the curve_fit() function’s official documentation.
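Both the initial estimate and parameter limits are keyword arguments of curve_fit(). p0 sets the starting parameter values; by default every parameter starts at 1. bounds takes a pair of lower and upper limits. The sketch below is an illustration under stated assumptions: it extends the sinusoidal model with a hypothetical angular-frequency parameter w. Frequency fits can converge to a wrong local minimum from a poor start, so the guess is supplied explicitly, and the amplitude is kept non-negative. The data, the seed, and the "rough" frequency guess are all assumptions.

```python
import numpy as np
from scipy.optimize import curve_fit

def model_freq(x, a, b, c, w):
    """Hypothetical sinusoid with a free angular frequency w."""
    return a * np.sin(w*x + b) + c

# Synthetic data (an assumption): a = 3, b = 0.5, c = 1, w = 2, seeded noise
rng = np.random.default_rng(5)
x = np.linspace(0, 10, 200)
y = model_freq(x, 3.0, 0.5, 1.0, 2.0) + rng.normal(scale=0.3, size=x.size)

# Initial guess read off the data: half the peak-to-peak range as amplitude,
# zero phase, the mean as offset, and an assumed rough frequency estimate
p0 = [(y.max() - y.min()) / 2, 0.0, y.mean(), 1.9]

# Lower and upper limits for (a, b, c, w): non-negative amplitude and frequency,
# phase within [-pi, pi], offset unrestricted
bounds = ([0, -np.pi, -np.inf, 0], [np.inf, np.pi, np.inf, np.inf])

par, cov = curve_fit(model_freq, x, y, p0=p0, bounds=bounds)
print("a, b, c, w =", np.round(par, 2))
```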