
How can batch gradient descent be implemented with sklearn for classification? We have SGDClassifier for stochastic GD, which takes a single instance at a time, and Linear/Logistic Regression, which use the normal equation.

desertnaut
k92
  • Check https://stackoverflow.com/questions/15036630/batch-gradient-descent-with-scikit-learn-sklearn (although perhaps not exactly what you are asking) – desertnaut Apr 06 '19 at 18:34
  • A negative answer (i.e. showing that you can actually *not* do something), as the one below, is a valid answer indeed. I kindly suggest you accept it. – desertnaut May 27 '20 at 16:17

1 Answer


A possible answer, pointed out in the other similar question and also given in the sklearn docs:

SGD allows minibatch (online/out-of-core) learning, see the partial_fit method.

But is partial_fit really batch gradient descent?

SGD: The gradient of the cost function is calculated and the weights are updated with a gradient descent step for each sample.

Batch/Mini-batch GD: The gradient of the cost function is calculated and the weights are updated with a gradient descent step once per batch.

So Batch GD with batch size of 1 == SGD.
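To make the two definitions concrete, here is a minimal pure-Python sketch (not sklearn code). It assumes a logistic loss, a fixed learning rate, and no shuffling; the helper names sigmoid, gradient and run_epoch are hypothetical and exist only for this illustration. The only thing that changes between the variants is how many samples feed each weight update.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gradient(w, X, y):
    """Mean log-loss gradient over the rows of X (labels y are 0 or 1)."""
    grad = [0.0] * len(w)
    for xi, yi in zip(X, y):
        err = sigmoid(sum(wj * xj for wj, xj in zip(w, xi))) - yi
        for j, xj in enumerate(xi):
            grad[j] += err * xj / len(X)
    return grad


def run_epoch(w, X, y, batch_size, lr=0.1):
    """One pass over the data, with exactly one weight update per batch."""
    w = list(w)
    n_updates = 0
    for start in range(0, len(X), batch_size):
        stop = start + batch_size
        grad = gradient(w, X[start:stop], y[start:stop])
        w = [wj - lr * gj for wj, gj in zip(w, grad)]
        n_updates += 1
    return w, n_updates


# Toy data (an assumption for the example): a bias column plus one feature.
X = [[1.0, -2.0], [1.0, -1.0], [1.0, 1.0], [1.0, 2.0]]
y = [0, 0, 1, 1]
w0 = [0.0, 0.0]

w_sgd, updates_sgd = run_epoch(w0, X, y, batch_size=1)         # SGD: one update per sample
w_mini, updates_mini = run_epoch(w0, X, y, batch_size=2)       # mini-batch GD
w_full, updates_full = run_epoch(w0, X, y, batch_size=len(X))  # full-batch GD

print(updates_sgd, updates_mini, updates_full)
```

With batch_size=1 the loop degenerates into per-sample updates, which is the sense in which batch GD with a batch size of 1 equals SGD.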

Now that we are clear about the definitions, let's investigate the code of sklearn's SGDClassifier.

The docstring of partial_fit says:

Perform one epoch of stochastic gradient descent on given samples.

But this is not batch GD; it looks more like a helper function that runs the fit method with max_iter=1 (in fact, it is commented as such in the docstrings).

partial_fit calls _partial_fit with max_iter=1. Reference link

The fit method calls _fit, which calls _partial_fit with max_iter set to the assigned/default maximum number of iterations. Reference link
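The call chain above can also be probed empirically. The sketch below is one possible check, not a definitive test: it assumes the fitted attribute t_ of SGDClassifier, which the sklearn docs describe as the number of weight updates performed during training (n_iter_ * n_samples + 1). The toy data and the batch size of 50 are made up for the example. If updates happened once per batch, t_ would advance by 1 per partial_fit call; per-sample updates advance it by the number of rows passed in.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

# Synthetic binary classification data (an assumption for this example).
rng = np.random.RandomState(0)
X = rng.randn(100, 3)
y = (X[:, 0] + X[:, 1] > 0).astype(int)

clf = SGDClassifier(random_state=0)

batch_size = 50
update_counts = []
for start in range(0, len(X), batch_size):
    rows = slice(start, start + batch_size)
    # classes must be supplied on the first partial_fit call.
    clf.partial_fit(X[rows], y[rows], classes=np.array([0, 1]))
    # t_ is sklearn's running counter of weight updates (starts at 1).
    update_counts.append(clf.t_)

print(update_counts)
```

Comparing consecutive entries of update_counts shows how many weight updates each partial_fit call performed on its 50-row batch.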

Conclusion:

partial_fit does not really do batch GD, i.e. it does not calculate the gradient and update the weights once per batch, but rather does so for each sample.

There seems to be no mechanism in sklearn to do batch gradient descent.

Ian
mujjiga