
I need to split a dataframe in Spark using Scala based on a given ratio. The split should be done on the version of the dataframe sorted by a particular column named ts. The first part is used for training and the last part is going to be used for validation.

val dataframe=//a sample dataframe
val trainRatio=0.8;
val training=//dataframe.rdd.orderBy("ts")
val test=//

Can someone give me a hint on how to do that?

Luckylukee

3 Answers


First, sort the df by the column you want:

val sortdf = df.sort($"ts".desc)

val Array(training_data, validat_data) = sortdf.randomSplit(Array(0.8, 0.2))
learner
  • this is what you need or anything else ?? – learner May 02 '17 at 09:18
  • suppose, I have an RDD of (1,2,3,4,5,6,7,8,9,10)..using ratio of 0.8, I should get an RDD of (1,2,3,4,5,6,7,8) for training and another RDD of (9,10) for test, your suggested answer gives me again random values. – Luckylukee May 02 '17 at 10:36
  • @learner I have only 11 values and I did split on [0.6, 0.3, 0.1] , but its dividing it in [6,5,0] or [8,3,0] I don't need zero as 11 can still be divided as [6,3,2] Is there any way to check to not get zero after split in train,test and valid – vipin Jun 28 '18 at 14:02
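As the comment above points out, randomSplit samples rows at random, so it cannot give a sequential head/tail split. A minimal sketch of a deterministic alternative on the sorted DataFrame, assuming a column named ts (the other names here are illustrative, and this is untested on a real cluster):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: deterministic first-80% / last-20% split on a DataFrame
// sorted ascending by ts. Names other than "ts" are assumptions.
val spark = SparkSession.builder().appName("seqSplit").master("local").getOrCreate()
import spark.implicits._

val df = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).map(i => (i, i)).toDF("id", "ts")

val sorted = df.orderBy($"ts".asc).cache()  // cache: counted and reused below
val trainRatio = 0.8
val trainSize = math.round(sorted.count() * trainRatio).toInt

val training = sorted.limit(trainSize)   // first 80% by ts
val validation = sorted.except(training) // the remaining rows
```

Note that except is a set difference, so rows duplicated across the boundary would be dropped; for data with duplicate rows, an index-based approach like the zipWithIndex answer below is safer.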

I am guessing my answer is here: first, find the percentile value in my RDD, and then a simple mapping function will divide the RDD correctly.
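A minimal sketch of that percentile idea, assuming the RDD holds plain timestamp values (all names here are illustrative, not tested on a cluster):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch: find the value at the 80th-percentile position of the sorted
// data, then divide the RDD with a simple comparison against it.
val conf = new SparkConf().setAppName("percentileSplit").setMaster("local")
val sc = new SparkContext(conf)

val ts = sc.parallelize(Array(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L))
val trainRatio = 0.8

// Sort once, attach positions, and read off the value at the cut position.
val sorted = ts.sortBy(identity).zipWithIndex().cache()
val cutIndex = math.round(sorted.count() * trainRatio) - 1
val cutValue = sorted.filter { case (_, i) => i == cutIndex }.first()._1

// The "simple mapping function": filter each side by the cut value.
val training = ts.filter(_ <= cutValue)
val validation = ts.filter(_ > cutValue)
```

With ties in ts this puts every row equal to the cut value into the training side, so the split sizes are only approximately 80/20 on duplicated data.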

Luckylukee

You can try the code below, but any better solution is much appreciated.

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("testApp").setMaster("local")
val sc = new SparkContext(conf)
val data = sc.parallelize(Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
val count = data.count()
val trainRatio = 0.8
val trainSize = math.round(count * trainRatio).toInt

// Pair each element with its position once, and cache so the
// indexing job is not recomputed for the second filter.
val indexed = data.zipWithIndex().cache()

val trainingRdd = indexed
  .filter { case (_, index) => index < trainSize }  // first trainSize elements
  .map { case (row, _) => row }
trainingRdd.foreach(println)

val testRdd = indexed
  .filter { case (_, index) => index >= trainSize } // the rest
  .map { case (row, _) => row }
testRdd.foreach(println)
Souvik