LinearRegression fits a linear model to data. In the case of one-dimensional X values like you have above, the result is a straight line (i.e. y = a + b*x). In the case of two-dimensional values, the result is a plane (i.e. z = a + b*x + c*y). So you can't expect a linear regression model to perfectly fit a quadratic curve: it simply doesn't have enough model complexity to do that.
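To make the limitation concrete, here is a minimal sketch that fits a plain straight line to purely quadratic data and inspects the misfit. The data (x from 0 to 9, y = x^2) and the variable names line, residuals and r2 are chosen for illustration only.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

x = np.arange(10)[:, None]   # shape (10, 1): a single feature column
y = np.ravel(x) ** 2         # purely quadratic target

# Fit a straight line y = a + b*x to the quadratic data
line = LinearRegression().fit(x, y)

residuals = y - line.predict(x)   # what the straight line fails to capture
r2 = line.score(x, y)             # coefficient of determination on the training data

print(line.intercept_, line.coef_)   # best-fit a and b
print(np.abs(residuals).max(), r2)   # nonzero residuals, R^2 below 1
```

The residuals are systematically positive at both ends and negative in the middle, which is the signature of a missing curvature term rather than of noise.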
That said, you can cleverly transform your input data in order to fit a quadratic curve with a linear regression model. Consider the 2D case above:
z = a + b*x + c*y
Now let's make the substitution y = x^2. That is, we add a second dimension to our data which contains the quadratic term. Now we have another linear model:
z = a + b*x + c*x^2
The result is a model that is quadratic in x, but still linear in the coefficients! That is, we can solve it easily via a linear regression: this is an example of a basis function expansion of the input data. Here it is in code:
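import numpy as np
from sklearn.linear_model import LinearRegression
x = np.arange(10)[:, None]   # column vector, shape (10, 1)
y = np.ravel(x) ** 2         # quadratic target
p = np.array([1, 2])         # powers to expand into: x and x^2
model = LinearRegression().fit(x ** p, y)   # x ** p broadcasts to shape (10, 2)
# predict() expects a 2D array of shape (n_samples, n_features)
model.predict(np.array([[11]]) ** p)
# array([121.]) up to floating-point error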
This is a bit awkward, though, because the fitted model expects both columns (x and x^2) as input to predict(), so you have to transform the input manually. If you want this transformation to happen automatically, you can use e.g. PolynomialFeatures in a pipeline:
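from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
model = make_pipeline(PolynomialFeatures(2), LinearRegression())
# predict() still needs a 2D array, but now with only the raw x column
model.fit(x, y).predict([[11]])
# array([121.]) up to floating-point error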
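If it helps to see what the pipeline does to the input before the regression step, the following small sketch applies PolynomialFeatures on its own. The variable names poly and expanded are just for illustration.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(2)   # degree 2, with the default bias column included

# A single sample with one feature, x = 11, as a 2D array of shape (1, 1)
expanded = poly.fit_transform(np.array([[11]]))

print(expanded)   # columns are [1, x, x^2], i.e. [[1., 11., 121.]]
```

The leading column of ones is the bias term that PolynomialFeatures adds by default; LinearRegression also fits its own intercept, so that column is redundant here but harmless.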
This is one of the beautiful things about linear models: using basis function expansion like this, they can be very flexible, while remaining very fast! You could think about adding columns with cubic, quartic, or other terms, and it's still a linear regression. Or for periodic models, you might think about adding columns of sines, cosines, etc. In the extreme limit of this, the so-called "kernel trick" allows you to effectively add an infinite number of new columns to your data, and end up with a model that is very powerful – but still linear and thus still relatively fast! For an example of this type of estimator, take a look at scikit-learn's KernelRidge.
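As a rough illustration of the periodic case, here is a sketch that adds sine and cosine columns by hand and fits them with an ordinary LinearRegression. The synthetic, noise-free signal and the names t, signal, features and periodic_model are assumptions made up for this example.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic, noise-free periodic data (an assumption for this example)
t = np.linspace(0, 2 * np.pi, 50)
signal = 1.0 + 2.0 * np.sin(t) - 0.5 * np.cos(t)

# Basis function expansion: one column of sines, one column of cosines
features = np.column_stack([np.sin(t), np.cos(t)])

periodic_model = LinearRegression().fit(features, signal)

# The model is nonlinear in t but linear in its coefficients,
# so the fit recovers roughly [2.0, -0.5] and an intercept near 1.0
print(periodic_model.coef_, periodic_model.intercept_)
```

With real data you would not know the frequency in advance, so you might add columns for several harmonics and let the regression decide how much weight each one gets.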