I would like to replicate each row as many times as the value it holds in a given column. For example, I have this DataFrame:

+-----+
|count|
+-----+
|    3|
|    1|
|    4|
+-----+

I would like to get:

+-----+
|count|
+-----+
|    3|
|    3|
|    3|
|    1|
|    4|
|    4|
|    4|
|    4|
+-----+

I tried using the withColumn method, following this answer:

val replicateDf = originalDf
    .withColumn("replicating", explode(array((1 until $"count").map(lit): _*)))
    .select("count")

But $"count" is a ColumnName and cannot be used to represent its values in the above expression.

(I also tried with explode(Array.fill($"count"){1}) but same problem here.)

What do I need to change? Is there a cleaner way?

thebluephantom
Baptiste Merliot

2 Answers

You can use the array_repeat function:

import org.apache.spark.sql.functions.{array_repeat, explode}

val df = Seq(1, 2, 3).toDF

df.select(explode(array_repeat($"value", $"value"))).show()
+---+
|col|
+---+
|  1|
|  2|
|  2|
|  3|
|  3|
|  3|
+---+
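
Applied to the question's DataFrame, the same idea might look like the sketch below. It assumes Spark 2.4 or later and an integer count column; the extra "label" column and the local SparkSession setup are hypothetical, added only to show that other columns are carried along with each replicated row.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_repeat, explode}

// Local session for illustration only; in spark-shell `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("replicate-rows")
  .getOrCreate()
import spark.implicits._

// Hypothetical data: "label" is not in the question, it only shows that
// other columns survive the replication.
val originalDf = Seq(("a", 3), ("b", 1), ("c", 4)).toDF("label", "count")

// Build an array holding `count` copies of `count`, explode it into one row
// per element, then drop the helper column.
val replicatedDf = originalDf
  .withColumn("replicating", explode(array_repeat($"count", $"count")))
  .drop("replicating")

replicatedDf.show()
```

Note that explode emits no row for an empty array, so a row whose count is 0 would disappear from the result.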
10465355

array_repeat is available from Spark 2.4 onwards. If you need a solution for earlier versions, you can use a UDF or the RDD API. With an RDD:

import org.apache.spark.sql.Row

val df = Seq(3,1,4).toDF("count")
// Emit each row as many times as its "count" value
val rdd1 = df.rdd.flatMap { x =>
  val y = x.getAs[Int]("count")
  for (p <- 0 until y) yield Row(y)
}
spark.createDataFrame(rdd1, df.schema).show(false)

Results:

+-----+
|count|
+-----+
|3    |
|3    |
|3    |
|1    |
|4    |
|4    |
|4    |
|4    |
+-----+

With flatMap directly on the DataFrame:

scala> df.flatMap( r=> { (0 until r.getInt(0)).map( i => r.getInt(0)) } ).show
+-----+
|value|
+-----+
|    3|
|    3|
|    3|
|    1|
|    4|
|    4|
|    4|
|    4|
+-----+
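
A typed variant of the same flatMap idea, as a minimal sketch: it assumes the data can be treated as a Dataset[Int], and the local SparkSession setup is only there to make the snippet self-contained. Because flatMap on a Dataset of primitives names its output column "value", the sketch renames it back to "count" at the end.

```scala
import org.apache.spark.sql.SparkSession

// Local session for illustration only; in spark-shell `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("replicate-rows-flatmap")
  .getOrCreate()
import spark.implicits._

val ds = Seq(3, 1, 4).toDS()

// Seq.fill(n)(n) yields n copies of n; flatMap flattens them into rows.
// toDF("count") restores the original column name.
val replicated = ds.flatMap(n => Seq.fill(n)(n)).toDF("count")

replicated.show()
```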

With a UDF, the following works:

import org.apache.spark.sql.functions.{explode, udf}

val df = Seq(3,1,4).toDF("count")
// Hand-rolled stand-in for the built-in array_repeat: x copies of x
def array_repeat(x:Int):Array[Int]={
  val y = for ( p <- 0 until x )yield x
  y.toArray
}
val udf_array_repeat = udf (array_repeat(_:Int):Array[Int] )
df.withColumn("count2", explode(udf_array_repeat($"count"))).select("count2").show(false)

EDIT:

Check @user10465355's answer for more information about array_repeat.

stack0114106