I have the following methods, which are part of the logic for performing stratified k-fold cross-validation:
private static IEnumerable<IEnumerable<int>> GenerateFolds(
    IClassificationProblemData problemData, int numberOfFolds)
{
    IRandom random = new MersenneTwister();
    IEnumerable<double> values = problemData.Dataset.GetDoubleValues(
        problemData.TargetVariable, problemData.TrainingIndices);
    var valuesIndices =
        problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v });
    IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass =
        valuesIndices.GroupBy(x => x.Value, x => x.Index)
                     .Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
    var enumerators = foldsByClass.Select(x => x.GetEnumerator()).ToList();
    while (enumerators.All(e => e.MoveNext()))
    {
        var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());
        yield return fold.ToList();
    }
}
Folds generation:
private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(
    IEnumerable<T> values, int valuesCount, int numberOfFolds)
{
    // f: base fold size (integer division), r: remainder;
    // the first r folds each receive one extra element
    int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds;
    int start = 0, end = f;
    for (int i = 0; i < numberOfFolds; ++i)
    {
        if (r > 0)
        {
            ++end;
            --r;
        }
        yield return values.Skip(start).Take(end - start);
        start = end;
        end += f;
    }
}
The generic GenerateFolds<T> method simply splits an IEnumerable<T> into a sequence of IEnumerables according to the specified number of folds. For example, if I had 101 training samples, it would generate one fold of size 11 and 9 folds of size 10.
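To make the size arithmetic concrete, here is a minimal sketch that computes only the fold sizes. FoldSizeSketch and FoldSizes are hypothetical names introduced for illustration and are not part of my actual code; the sketch assumes the same rule as above, namely that the first r folds receive one extra element.

```csharp
using System;
using System.Collections.Generic;

public static class FoldSizeSketch
{
    // Hypothetical helper (illustration only): returns the sizes of the
    // folds that the start/end arithmetic in GenerateFolds<T> produces.
    public static List<int> FoldSizes(int valuesCount, int numberOfFolds)
    {
        int f = valuesCount / numberOfFolds; // base fold size
        int r = valuesCount % numberOfFolds; // number of folds with one extra element
        var sizes = new List<int>();
        for (int i = 0; i < numberOfFolds; ++i)
        {
            // the first r folds are one element larger than the rest
            sizes.Add(i < r ? f + 1 : f);
        }
        return sizes;
    }
}
```

Calling FoldSizeSketch.FoldSizes(101, 10) gives one fold of size 11 followed by nine folds of size 10, and the sizes always sum to valuesCount.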
The method above it groups the samples based on class values, splits each group into the specified number of folds and then joins the by-class folds into the final folds, ensuring the same distribution of class labels.
My question concerns the line yield return fold.ToList(). As written, the method works correctly; if I remove the ToList(), however, the results are no longer correct. In my test case I have 641 training samples and 10 folds, which means the first fold should be of size 65 and the remaining folds of size 64. But when I remove ToList(), all the folds are of size 64 and the class labels are not correctly distributed. Any ideas why? Thank you.