Although there are many seemingly similar questions, none of them answers mine.
I have a DataFrame already processed in order to be fed to a DecisionTreeClassifier, and it contains a column label which is filled with either 0.0 or 1.0.
I need to bootstrap my data set by randomly selecting, with replacement, the same number of rows for each value of my label column.
I've looked through the documentation, and all I could find are DataFrame.sample(...) and DataFrameStatFunctions.sampleBy(...). The issue with both is that the number of rows retained is not guaranteed, and the second one doesn't allow replacement! This wouldn't be an issue on a larger data set, but in around 50% of my cases one of the label values will have fewer than a hundred rows, and I really don't want skewed data.
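To illustrate the first problem: as far as I understand, sampling without replacement decides row by row whether a row is kept, so the fraction is only an expected value and not a guarantee. The plain-Scala sketch below involves no Spark at all and only mimics that behaviour. bernoulliSampleSize is a made-up helper for illustration, not a Spark API, and the 80-row class size and 0.5 fraction are arbitrary assumptions.

```scala
import scala.util.Random

// Hypothetical stand-in for per-row sampling without replacement:
// each row is kept independently with probability `fraction`.
def bernoulliSampleSize(nbRows: Int, fraction: Double, rng: Random): Int =
  (0 until nbRows).count(_ => rng.nextDouble() < fraction)

val rng = new Random(42L)

// Ask for "half" of an 80-row class, ten times in a row.
val sizes = Seq.fill(10)(bernoulliSampleSize(80, 0.5, rng))

// The sizes scatter around 40 instead of being exactly 40 every time.
println(sizes.mkString(", "))
```

On a class with fewer than a hundred rows, that scatter is large relative to the class size, which is exactly the skew I want to avoid.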
Despite my best efforts, I was unable to find a clean solution to this problem, so I resigned myself to collecting the whole DataFrame and doing the sampling "manually" in Scala before creating a new DataFrame to train my DecisionTreeClassifier on. But this seems highly inefficient and cumbersome; I would much rather stay with DataFrames and keep all the benefits that come with that structure.
Here is my current implementation for reference and so you know exactly what I'd like to do:
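import scala.collection.JavaConverters._
import scala.util.Random

val nbSamplePerClass = /* some int value currently ranging between 50 and 10000 */
// Split the rows by label and collect both halves to the driver.
val onesDataFrame = inputDataFrame.filter("label > 0.0")
val zeros = inputDataFrame.except(onesDataFrame).collect()
val ones = onesDataFrame.collect()
// collect() returns an Array[Row], so use length (count would need a predicate).
val nbZeros = zeros.length
val nbOnes = ones.length
// One shared generator instead of a new Random per draw.
val rng = new Random()
// Draw nbSamplePerClass indexes uniformly, with replacement, from [0, maxIndex).
def randomIndexes(maxIndex: Int): Seq[Int] =
  Seq.fill(nbSamplePerClass)(rng.nextInt(maxIndex))
val zerosSample = randomIndexes(nbZeros).map(idx => zeros(idx))
val onesSample = randomIndexes(nbOnes).map(idx => ones(idx))
// createDataFrame expects a java.util.List[Row] here.
val samples = (zerosSample ++ onesSample).asJava
val resDf = sqlContext.createDataFrame(samples, inputDataFrame.schema)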
Does anyone know how I could implement such a sampling while only working with DataFrames? I'm pretty sure that it would significantly speed up my code! Thank you for your time.