
Despite the many seemingly similar questions out there, none answers mine.

I have a DataFrame already processed so that it can be fed to a DecisionTreeClassifier, and it contains a label column filled with either 0.0 or 1.0.

I need to bootstrap my data set by randomly selecting, with replacement, the same number of rows for each value of my label column.

I've looked through all the docs and all I could find are DataFrame.sample(...) and DataFrameStatFunctions.sampleBy(...), but the issue with those is that the number of samples retained is not guaranteed, and the second one doesn't allow replacement! This wouldn't be a problem on larger data sets, but in around 50% of my cases one of the label values has fewer than a hundred rows, and I really don't want skewed data.
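
To make the problem concrete, this is roughly how those two methods are called; both take fractions rather than absolute counts, which is where the guarantee is lost (the 0.5 fractions here are just placeholders):

val approx = inputDataFrame.sample(withReplacement = true, fraction = 0.5, seed = 42L)      // with replacement, but size only approximate
val stratified = inputDataFrame.stat.sampleBy("label", Map(0.0 -> 0.5, 1.0 -> 0.5), 42L)   // stratified, but no replacement option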

Despite my best efforts, I was unable to find a clean solution to this problem, so I resigned myself to collecting the whole DataFrame and doing the sampling "manually" in Scala before recreating a new DataFrame to train my DecisionTreeClassifier on. But this seems highly inefficient and cumbersome; I would much rather stay with DataFrames and keep all the benefits of that structure.

Here is my current implementation for reference and so you know exactly what I'd like to do:

val nbSamplePerClass = /* some int value currently ranging between 50 and 10000 */

// Split by class, then pull both sides onto the driver
val onesDataFrame = inputDataFrame.filter("label > 0.0")
val zeros = inputDataFrame.except(onesDataFrame).collect()
val ones = onesDataFrame.collect()

val nbZeros = zeros.length
val nbOnes = ones.length

// Draw nbSamplePerClass indexes uniformly, with replacement, from one shared RNG
val rng = new scala.util.Random()
def randomIndexes(maxIndex: Int): Seq[Int] =
  Seq.fill(nbSamplePerClass)(rng.nextInt(maxIndex))

val zerosSample = randomIndexes(nbZeros).map(idx => zeros(idx))
val onesSample = randomIndexes(nbOnes).map(idx => ones(idx))

// Rebuild a DataFrame from the sampled rows
val samples = scala.collection.JavaConversions.seqAsJavaList(zerosSample ++ onesSample)
val resDf = sqlContext.createDataFrame(samples, inputDataFrame.schema)

Does anyone know how I could implement such a sampling while only working with DataFrames? I'm pretty sure that it would significantly speed up my code! Thank you for your time.

Tristan O.
  • I've already answered a similar question [here](http://stackoverflow.com/questions/32238727/stratified-sampling-in-spark). Did you take a look at it? – eliasah Jul 21 '16 at 16:24
  • I missed that post, I'm diving into it! I will report back! Thanks – Tristan O. Jul 21 '16 at 16:29
  • It's an RDD-based answer but it should do the job. Let me know what you think, and don't forget to upvote if it helps you! :) – eliasah Jul 21 '16 at 16:31
  • Okay, it does what I want to do, and it seems like a better approach than Scala Arrays! It's a big step up! I'll use your solution for now, thanks. But creating the PairRDD and transforming it back to a DataFrame is still quite cumbersome; I'd really prefer a way to do such a thing with DataFrames directly. I hope someone has that answer. :) In any case, thank you for this improvement and your time. – Tristan O. Jul 21 '16 at 16:49
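
Edit: for future readers, here is roughly what the RDD detour from the comments looks like. This is just a sketch under my setup above (label is 0.0/1.0, nbSamplePerClass and sqlContext are in scope); sampleByKeyExact is what does the exact-size, with-replacement stratified draw:

// Key each Row by its label so we get a pair RDD
val byLabel = inputDataFrame.rdd.keyBy(_.getAs[Double]("label"))
// Count rows per class to turn the absolute target into per-class fractions
val counts = byLabel.countByKey()
val fractions = counts.map { case (label, n) =>
  label -> nbSamplePerClass.toDouble / n   // may exceed 1.0, which is fine with replacement
}.toMap
// Exact-size stratified sampling, with replacement
val sampled = byLabel.sampleByKeyExact(withReplacement = true, fractions)
// Drop the keys and rebuild a DataFrame with the original schema
val resDf = sqlContext.createDataFrame(sampled.values, inputDataFrame.schema)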

0 Answers