
I am having a lot of trouble understanding how the class_weight parameter in scikit-learn's Logistic Regression operates.

The Situation

I want to use logistic regression to do binary classification on a very unbalanced data set. The classes are labelled 0 (negative) and 1 (positive), and the observed data are in a ratio of about 19:1, with the majority of samples having a negative outcome.

First Attempt: Manually Preparing Training Data

I split the data I had into disjoint sets for training and testing (about 80/20). Then I randomly sampled the training data by hand to get training data in different proportions than 19:1, ranging from 2:1 to 16:1.

I then trained logistic regression on these different training data subsets and plotted recall (= TP/(TP+FN)) as a function of the different training proportions. Of course, the recall was computed on the disjoint TEST samples, which had the observed proportions of 19:1. Note that although I trained the different models on different training data, I computed recall for all of them on the same (disjoint) test data.
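
Roughly, the procedure looked like the sketch below (illustrative only; in reality X and y come from my own data set, here replaced by a synthetic stand-in so the snippet runs on its own):

# Sketch of the manual experiment (variable names illustrative)
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score

# synthetic stand-in for my real data, roughly 19:1 negative:positive
X, y = make_classification(n_samples=20000, weights=[0.95], flip_y=0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

pos_idx = np.where(y_train == 1)[0]
neg_idx = np.where(y_train == 0)[0]

for ratio in [2, 4, 6, 8, 16]:  # negative:positive proportion of the training subset
    keep_neg = np.random.choice(neg_idx, size=ratio * len(pos_idx), replace=False)
    subset = np.concatenate([keep_neg, pos_idx])
    clf = LogisticRegression().fit(X_train[subset], y_train[subset])
    # recall is always measured on the untouched test set (still ~19:1)
    print(ratio, recall_score(y_test, clf.predict(X_test)))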

The results were as expected: the recall was about 60% at 2:1 training proportions and fell off rather fast by the time it got to 16:1. There were several proportions, from 2:1 to 6:1, where the recall was decently above 5%.

Second Attempt: Grid Search

Next, I wanted to test different regularization parameters, so I used GridSearchCV and made a grid of several values of the C parameter as well as the class_weight parameter. To translate my n:m proportions of negative:positive training samples into the dictionary language of class_weight, I thought I would just specify several dictionaries as follows:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

and I also included None and auto.
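
In other words, the grid looked roughly like the sketch below (the C values and the scorer are illustrative; 'auto' is the value in the scikit-learn version I am using, which newer releases renamed to 'balanced'):

# Sketch of the grid search setup (C values and scoring illustrative)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

param_grid = {
    'C': [0.01, 0.1, 1, 10, 100],
    'class_weight': [{0: 0.67, 1: 0.33},   # expected 2:1
                     {0: 0.75, 1: 0.25},   # expected 3:1
                     {0: 0.8,  1: 0.2},    # expected 4:1
                     None, 'auto'],
}
grid = GridSearchCV(LogisticRegression(), param_grid, scoring='recall', cv=5)
grid.fit(X_train, y_train)   # X_train, y_train as in the first attempt above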

This time the results were totally whacked. All my recalls came out tiny (< 0.05) for every value of class_weight except 'auto'. So I can only assume that my understanding of how to set the class_weight dictionary is wrong. Interestingly, the recall for class_weight='auto' in the grid search was around 59% for all values of C, so I guessed it balances the classes to 1:1?

My Questions

  1. How do you properly use class_weight to achieve a different balance in the training data than the one you actually give it? Specifically, what dictionary do I pass to class_weight to use n:m proportions of negative:positive training samples?

  2. If you pass various class_weight dictionaries to GridSearchCV, during cross-validation will it rebalance the training fold data according to the dictionary but use the true given sample proportions for computing my scoring function on the test fold? This is critical since any metric is only useful to me if it comes from data in the observed proportions.

  3. What does the auto value of class_weight do as far as proportions? I read the documentation and I assume "balances the data inversely proportional to their frequency" just means it makes it 1:1. Is this correct? If not, can someone clarify?

ROMANIA_engineer
kilgoretrout
  • When one uses class_weight, the loss function gets modified. For instance, instead of cross-entropy, it becomes weighted cross-entropy. https://towardsdatascience.com/practical-tips-for-class-imbalance-in-binary-classification-6ee29bcdb8a7 – prashanth Sep 25 '19 at 18:01

2 Answers


First off, it might not be good to just go by recall alone. You can simply achieve a recall of 100% by classifying everything as the positive class. I usually suggest using AUC for selecting parameters, and then finding a threshold for the operating point (say a given precision level) that you are interested in.
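
For example, a minimal sketch of that workflow (synthetic data as a stand-in; the 80% precision target is arbitrary):

# Sketch: compare models by ROC AUC, then choose an operating threshold afterwards
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, precision_recall_curve

X, y = make_classification(n_samples=5000, weights=[0.95], random_state=0)
clf = LogisticRegression(class_weight={0: 0.1, 1: 0.9}).fit(X, y)

proba = clf.predict_proba(X)[:, 1]           # in practice, use held-out data here
print(roc_auc_score(y, proba))               # threshold-free score for model selection

# afterwards, pick the smallest threshold that reaches ~80% precision
precision, recall, thresholds = precision_recall_curve(y, proba)
idx = np.argmax(precision[:-1] >= 0.8)       # first threshold meeting the precision target
y_pred = (proba >= thresholds[idx]).astype(int)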

For how class_weight works: It penalizes mistakes in samples of class[i] with class_weight[i] instead of 1. So higher class-weight means you want to put more emphasis on a class. From what you say it seems class 0 is 19 times more frequent than class 1. So you should increase the class_weight of class 1 relative to class 0, say {0:.1, 1:.9}. If the class_weight doesn't sum to 1, it will basically change the regularization parameter.
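
To make the mechanics concrete, here is a small sketch (synthetic data; the weights are just examples). Each sample's contribution to the loss is multiplied by the weight of its class, so up-weighting class 1 pushes the model toward predicting class 1 more often:

# Sketch: effect of up-weighting the rare class (data and weights illustrative)
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=2000, weights=[0.95], random_state=0)  # ~5% positives

plain    = LogisticRegression().fit(X, y)                               # every mistake costs 1
weighted = LogisticRegression(class_weight={0: 0.1, 1: 0.9}).fit(X, y)  # mistakes on class 1 cost 9x more

print(plain.predict(X).mean())      # small fraction predicted positive
print(weighted.predict(X).mean())   # noticeably larger fraction predicted positive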

For how class_weight="auto" works, you can have a look at this discussion. In the dev version you can use class_weight="balanced", which is easier to understand: it basically means replicating the smaller class until you have as many samples as in the larger one, but in an implicit way.
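
For illustration, the per-class weights that 'balanced' resolves to can be computed explicitly with scikit-learn's compute_class_weight helper (toy 19:1 labels below):

# Sketch: what class_weight='balanced' resolves to,
# i.e. n_samples / (n_classes * np.bincount(y))
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 95 + [1] * 5)    # toy labels at roughly 19:1
w = compute_class_weight('balanced', classes=np.array([0, 1]), y=y)
print(w)   # roughly [0.53, 10.0] -> the rare class gets a much larger weight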

Andreas Mueller
  • Thanks! Quick question: I mentioned recall for clarity and in fact I am trying to decide which AUC to use as my measure. My understanding is that I should be either maximizing area under the ROC curve or area under the recall vs. precision curve to find parameters. After picking the parameters this way, I believe I choose the threshold for classification by sliding along the curve. Is this what you meant? If so, which of the two curves makes the most sense to look at if my goal is to capture as many TPs as possible? Also, thank you for your work and contributions to scikit-learn!!! – kilgoretrout Jun 22 '15 at 19:35
  • I think using ROC would be the more standard way to go, but I don't think there will be a huge difference. You do need some criterion to pick the point on the curve, though. – Andreas Mueller Jun 22 '15 at 20:14
  • Isn't there also the idea to more heavily penalise a misclassification of your smaller set in this scenario? Although I agree the thing to try is the balanced setting for class_weight parameter. – Luke Barker Aug 23 '16 at 17:10
  • What did you mean by replicating the smaller class implicitly? I guess you are not saying 'upsampling'. Can you elaborate? @AndreasMueller – N. F. Mar 01 '18 at 04:12
  • @MiNdFrEaK I think what Andreas means is that the estimator replicates samples in the minority class, so that samples of different classes are balanced. It's just oversampling in an implicit way. – Shawn TIAN May 07 '18 at 13:13
  • @MiNdFrEaK and Shawn Tian: SV-based classifiers **do not** produce more samples of the smaller classes when you use 'balanced'. It literally penalizes mistakes made on the smaller classes. To say otherwise is a mistake and is misleading, especially in large datasets when you cannot afford creating more samples. This answer must be edited. – Pablo Rivas Aug 28 '18 at 12:08
  • Hi Andreas. It is not entirely clear from your answer and from the sklearn docs how class_weight works. Does it modify the cost function by putting class weights different than 1, or does it oversample and undersample the classes? These may have the same results but they are quite distinct methods. – Outcast Jun 17 '19 at 15:14
  • @PoeteMaudit I opened https://github.com/scikit-learn/scikit-learn/issues/14111; indeed I don't think the docs are very clear. It modifies the loss function. What that means is a bit algorithm-specific; for example, for decision trees it means reweighting the criterion (while for gradient boosting it's the actual loss). – Andreas Mueller Jun 17 '19 at 19:13
  • Thanks for your answer Andreas. So it modifies the algorithm/loss-function; that's good to know. I agree with you that the docs should have been a bit clearer there since for example changing the class weights by modifying the algorithm/loss-function or oversampling/undersampling are pretty different methods (in terms of how the data are utilised by the algorithm and the potential repercussions of that). – Outcast Jun 18 '19 at 10:40
  • https://scikit-learn.org/dev/glossary.html#term-class-weight Class weights will be used differently depending on the algorithm: for linear models (such as linear SVM or logistic regression), the class weights will alter the loss function by weighting the loss of each sample by its class weight. For tree-based algorithms, the class weights will be used for reweighting the splitting criterion. Note however that this rebalancing does not take the weight of samples in each class into account. – prashanth Sep 25 '19 at 18:16
  • @Andreas I find in the documentation that for tree-based algorithms it weights the splitting criterion. When I use the balanced option and visualize the tree, it shows decimal numbers for the value parameter. Are the values symbolic? – keramat Oct 26 '20 at 21:13
  • These are the weighted sample counts. If that's not explained in the docs, please open an issue so we can add it. – Andreas Mueller Oct 27 '20 at 01:25
  • So lots of theory here; can anyone tell me how class_weight modifies the loss mathematically? – Aman Dalmia Jul 23 '21 at 18:47

The first answer is good for understanding how it works. But I wanted to understand how I should be using it in practice.

Use imbalanced-learn

For imbalanced data, the methods in imbalanced-learn produce better results, in-sample and especially out-of-sample, than using the class_weight parameter.
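
For context, a minimal sketch of what that looks like (SMOTE is just one of the resamplers in imbalanced-learn; the data below are synthetic):

# Sketch: resampling with imbalanced-learn instead of class_weight
from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=750, n_features=5, n_informative=2,
                           n_redundant=0, weights=[0.95], random_state=1)
X_res, y_res = SMOTE(random_state=0).fit_resample(X, y)   # classes roughly 1:1 after resampling
clf = LogisticRegression(C=1e9).fit(X_res, y_res)
print(y.mean(), y_res.mean())   # imbalance before vs. after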

SUMMARY

  • for moderately imbalanced data WITHOUT noise, there is not much of a difference in applying class weights
  • for moderately imbalanced data WITH noise, and for strongly imbalanced data, it is better to apply class weights
  • class_weight="balanced" works decently if you do not want to optimize the weights manually
  • with class_weight="balanced" you capture more true events (higher TRUE recall) but you are also more likely to get false alerts (lower TRUE precision)
    • as a result, the total % predicted TRUE might be higher than the actual % TRUE because of all the false positives
    • AUC might mislead you here if the false alarms are an issue
  • no need to move the decision threshold to the imbalance %; even for strong imbalance it is fine to keep 0.5 (or somewhere around that, depending on what you need)

NB

The result might differ when using RF or GBM. sklearn does not have class_weight="balanced" for GBM, but lightgbm has LGBMClassifier(is_unbalance=False).

CODE

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few predicted TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also a lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
citynorman