
I’m trying to replicate some basic linear regression results from scikit-learn’s LASSO implementation with various PyTorch linear and non-linear models (dense & LSTM), and I’m finding that the solution quality is nowhere near as good. I’ve tried this two separate ways. The first is explicitly adding a penalty to the loss function, similar to what is described here: python - Pytorch: how to add L1 regularizer to activations? - Stack Overflow
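
For reference, here's a minimal sketch of that first approach on synthetic data (the data and the `lam` value are illustrative, not my real setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_samples, n_features, n_epochs = 200, 10, 500
X_train = torch.randn(n_samples, n_features)
y_train = X_train @ torch.randn(n_features, 1) + 0.1 * torch.randn(n_samples, 1)

model = nn.Linear(n_features, 1)            # plain linear model as the baseline
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lam = 1e-3                                  # L1 strength (illustrative), playing the role of sklearn's alpha

for epoch in range(n_epochs):
    optimizer.zero_grad()
    mse = criterion(model(X_train), y_train)
    l1 = model.weight.abs().sum()           # penalize the coefficients only; sklearn's Lasso leaves the intercept unpenalized
    loss = mse + lam * l1
    loss.backward()
    optimizer.step()
```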

The second way I’ve tried is using AdamW and effectively treating the weight decay the same way lambda is used in LASSO (creating curves of training and cross-validation error as a function of lambda).
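
And a sketch of the second approach (again, the decay values and data are only illustrative). As I understand it, AdamW's weight decay is a decoupled L2-style shrinkage rather than an L1 penalty, so the analogy to LASSO's lambda is loose:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_samples, n_features, n_epochs = 200, 10, 500
X_train = torch.randn(n_samples, n_features)
y_train = X_train @ torch.randn(n_features, 1) + 0.1 * torch.randn(n_samples, 1)

criterion = nn.MSELoss()
final_train_loss = {}
for wd in [1e-4, 1e-3, 1e-2, 1e-1]:          # sweep weight_decay the way I'd sweep LASSO's lambda
    model = nn.Linear(n_features, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=wd)
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        loss = criterion(model(X_train), y_train)
        loss.backward()
        optimizer.step()
    final_train_loss[wd] = loss.item()        # in the real runs I track training and cross-validation error per lambda
```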

In both cases, I get sub-optimal results vs. scikit-learn (i.e., the validation and test set objective functions are higher). I also get a substantial amount of noise during training, even with a low learning rate (1.0e-4) and a fairly high weight decay:

[Figure: train loss vs. epoch]

Any thoughts on how to get the neural network models closer to the linear regression results and how to smooth out the optimization?

DrJubbs
  • I don't get your target. Which scikit-learn implementation are you talking about? The traditional one for linear models? If so, why are you using neural nets to fit a linear model? If you are targeting simple L1 regularization for neural nets, just add it to your loss function as in the linear case and let backprop work it out for you – SystemSigma_ Apr 07 '23 at 13:55
  • Yes, the traditional one, `sklearn.linear_model.Lasso`. I'm fitting a linear model as a baseline. The goal is to out-perform the linear model using either a deep neural network or an LSTM model, but I'm being a good data scientist and comparing myself against a trivial linear model first. However, I can't seem to beat or even match the linear model, even when using a simple linear network in PyTorch. I did add the L1 penalty to the loss function and ran backprop, and the solution quality is significantly worse than that obtained from scikit-learn. – DrJubbs Apr 08 '23 at 14:15
  • If the model is truly linear, you will not outperform it with nonlinear models. However, it seems to be an implementation issue. Maybe posting some code would help debug it – SystemSigma_ Apr 09 '23 at 17:43

0 Answers