
I would like to run a randomized search for optimal XGBoost parameters, with additional augmented data added to the training set of each cross-validation fold. Currently, I add the augmented data inside the split loop of my cross-validation method, after getting a random parameter dictionary from ParameterSampler. Like so:

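import pandas as pd
from sklearn.model_selection import RepeatedStratifiedKFold
from xgboost import XGBClassifier

cv = RepeatedStratifiedKFold(
    n_splits=5,
    n_repeats=5,
    random_state=101,
)

for params in random_param_samples:
    model = XGBClassifier(**params)
    for i, (train, test) in enumerate(cv.split(X, y)):
        # Simplified just to show the idea: append the augmented rows that
        # belong to the training rows. augmented_X / augmented_y are
        # placeholder names for the augmented features and labels.
        new_X = pd.concat([X.iloc[train, :], augmented_X.iloc[train, :]])
        new_y = pd.concat([y.iloc[train], augmented_y.iloc[train]])
        model.fit(new_X, new_y)
        # Predict on the original, non-augmented test rows only
        y_pred = model.predict(X.iloc[test, :])
    # collect my scores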

It is working, but it creates a lot of overhead everywhere.

I was wondering whether there is a smart way to tell RandomizedSearchCV, and basically every sklearn method that accepts a cross-validation method, to train with augmented data and test on the original data without augmentations. It would be great to be able to use cross_validate and similar methods with augmented data.

Maybe by passing the original data and the augmented data as one DataFrame and telling the cv method which train indices are allowed, and that test indices may only come from the original data?
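
The idea in the last paragraph can be sketched roughly as follows. This is only an illustration under stated assumptions: the augmentation (Gaussian jitter), the helper name `augmented_splits`, and the use of `LogisticRegression` as a stand-in for `XGBClassifier` are all hypothetical, and each original row `i` is assumed to have exactly one augmented copy stored at row `n + i` of the stacked data. It relies on the fact that sklearn's `cv` argument also accepts an iterable of `(train, test)` index arrays.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate

# Toy original data (stand-in for the real X, y)
X, y = make_classification(n_samples=100, n_features=5, random_state=0)
n = len(X)

# Hypothetical augmentation: one jittered copy per original row
rng = np.random.default_rng(0)
X_aug = X + rng.normal(scale=0.05, size=X.shape)

# Stack original rows first, augmented rows second:
# the augmented copy of original row i sits at row n + i
X_all = np.vstack([X, X_aug])
y_all = np.concatenate([y, y])


def augmented_splits(cv, X, y, n_original):
    """Yield (train, test) indices into the stacked data.

    Splits are computed on the original rows only. Each train fold is
    extended with the augmented copies of its own rows, while the test
    fold keeps original rows only.
    """
    for train, test in cv.split(X, y):
        yield np.concatenate([train, train + n_original]), test


cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=101)
splits = list(augmented_splits(cv, X, y, n))

# Any sklearn utility that takes `cv` can consume the precomputed splits
scores = cross_validate(
    LogisticRegression(max_iter=1000), X_all, y_all, cv=splits
)
print(scores["test_score"].mean())
```

The same `splits` list could presumably be passed as `cv=splits` to `RandomizedSearchCV`. Because the augmented copies of test rows are never added to the training indices, no augmented version of a test sample leaks into training.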

desertnaut
nilsfl
  • There are many ways to achieve this. One is to split the data so that only the training data contains augmented data (using something like [PredefinedSplit](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.PredefinedSplit.html)). Or you can augment your data on the fly using a [Pipeline](https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.pipeline.Pipeline.html). Note that this is the imbalanced-learn Pipeline, which can handle an increase or decrease in training data without touching the test data. Try these, and if you still have issues, update the question. – Vivek Kumar Jun 12 '20 at 11:43
  • Thanks for your comment. PredefinedSplit sounds very helpful. Does it mean that I could build my own train indices and test indices from a DataFrame with augmented data and pass PredefinedSplit.split() to my RandomizedSearchCV? Regarding the Pipeline, this would be inside my fit loop for random parameters, so it would only change my method a bit but wouldn't work directly with RandomizedSearchCV, or am I missing something? – nilsfl Jun 12 '20 at 11:57
  • You can use the pipeline inside the `RandomizedSearchCV` instead of `XGBClassifier` – Vivek Kumar Jun 12 '20 at 11:59
  • I understand that I can pass a Pipeline to `RandomizedSearchCV`, but let's say I have a Pipeline with `SMOTE()` and `XGBClassifier`. I will still have to pass a cv method to `RandomizedSearchCV`, and that is the point where I don't get how `RandomizedSearchCV` knows that it should only test on original data and train on original and augmented data. – nilsfl Jun 12 '20 at 12:15
  • No, when you have your Pipeline, the data will be split first and then the training data will be passed to `SMOTE` for augmentation. The test data will not be augmented, only passed directly to `XGBClassifier`, irrespective of the `cv` type you pass. – Vivek Kumar Jun 12 '20 at 12:22
  • If it is still unclear, I will add an answer. – Vivek Kumar Jun 12 '20 at 12:23
  • Ah, so this is special about the imblearn Pipeline I guess? This is actually great! You don't need to add an answer. I got it :) – nilsfl Jun 12 '20 at 12:34
  • Possibly helpful: [Correct way to do cross validation in a pipeline with imbalanced data](https://stackoverflow.com/questions/62308095/correct-way-to-do-cross-validation-in-a-pipeline-with-imbalanced-data) – desertnaut Jun 12 '20 at 13:59
  • My answer explaining how the imblearn pipeline works: https://stackoverflow.com/questions/50245684/using-smote-with-gridsearchcv-in-scikit-learn/50245954#50245954 – Vivek Kumar Jun 12 '20 at 16:28

0 Answers