
I've got a DataFrame like this, and I want to duplicate each row n times when the column n is bigger than one:

A   B   n  
1   2   1  
2   9   1  
3   8   2    
4   1   1    
5   3   3 

And transform it into this:

A   B   n  
1   2   1  
2   9   1  
3   8   2
3   8   2       
4   1   1    
5   3   3 
5   3   3 
5   3   3 

I think I should use explode, but I don't understand how it works...
Thanks


3 Answers


With Spark 2.4.0+, this is easier with the built-in functions array_repeat + explode:

from pyspark.sql.functions import expr

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])

new_df = df.withColumn('n', expr('explode(array_repeat(n,int(n)))'))

>>> new_df.show()
+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+
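
If it helps to see what explode is doing, you can first materialize the repeated array in its own column (a sketch in the same SQL-expression style; the column name n_arr is only for illustration):

from pyspark.sql.functions import expr

# array_repeat builds the array first: n=1 -> [1], n=2 -> [2, 2], n=3 -> [3, 3, 3];
# explode then turns each array element into its own output row, as shown above.
df.withColumn('n_arr', expr('array_repeat(n, int(n))')).show()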
jxc
  • When using the `array_repeat` API function in pyspark, how can you reference n for the second parameter (count)? I get `Column is not iterable` when trying with `F.col()`. – David Foster Dec 21 '21 at 11:28
  • 1
    @DavidFoster, this can not be done using pyspark API functions, check this *Frequent* [question](https://stackoverflow.com/q/51140470/9510729). using SQL expressions, this is not a problem and IMO the code is often more concise and easy to maintain. – jxc Dec 21 '21 at 13:47
  • Thanks for confirming that it's not supported. It seems odd that the Python API doesn't support this when it's simple in the underlying SQL. That's a fair point on brevity; however, in this case it looks just like the Python solution would, and my IDE will treat it as a string rather than functions/methods etc. – David Foster Dec 22 '21 at 15:31
  • @jxc array_repeat does not work on a long-type column – Dariusz Krynicki Aug 03 '22 at 06:53
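
If you would rather stay with the pyspark API functions discussed in the comments above, a middle ground (a sketch, assuming Spark 2.4+ and the df defined earlier) is to wrap only the array_repeat call in expr and keep explode as an API call; the int(n) cast on the count argument should also avoid the long-type issue mentioned in the last comment:

from pyspark.sql import functions as F

# only the repeat count lives in the SQL expression; explode stays a regular API call
new_df = df.withColumn('n', F.explode(F.expr('array_repeat(n, int(n))')))
new_df.show()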

The explode function returns a new row for each element in the given array or map.

One way to use it here is to have a udf build a list of length n for each row, and then explode the resulting array.

from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType
    
df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])
df.show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
+---+---+---+

# udf that turns the value n into a list containing n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))
df2.show()

+---+---+---------+
|  A|  B|        n|
+---+---+---------+
|  1|  2|      [1]|
|  2|  9|      [1]|
|  3|  8|   [2, 2]|
|  4|  1|      [1]|
|  5|  3|[3, 3, 3]|
+---+---+---------+ 

# now use explode  
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+
Ahmed

I think the udf answer by @Ahmed is the best way to go, but here is an alternative method that may be as good as or better than it for small n:

First, collect the maximum value of n over the whole DataFrame:

import pyspark.sql.functions as f

max_n = df.select(f.max('n').alias('max_n')).first()['max_n']
print(max_n)
#3

Now create an array of length max_n for each row, containing the numbers in range(max_n). This intermediate step results in a DataFrame like:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)])).show()
#+---+---+---+---------+
#|  A|  B|  n|  n_array|
#+---+---+---+---------+
#|  1|  2|  1|[0, 1, 2]|
#|  2|  9|  1|[0, 1, 2]|
#|  3|  8|  2|[0, 1, 2]|
#|  4|  1|  1|[0, 1, 2]|
#|  5|  3|  3|[0, 1, 2]|
#+---+---+---+---------+

Now we explode the n_array column, and filter to keep only the values in the array that are less than n. This will ensure that we have n copies of each row. Finally we drop the exploded column to get the end result:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
    .select('A', 'B', 'n', f.explode('n_array').alias('col'))\
    .where(f.col('col') < f.col('n'))\
    .drop('col')\
    .show()
#+---+---+---+
#|  A|  B|  n|
#+---+---+---+
#|  1|  2|  1|
#|  2|  9|  1|
#|  3|  8|  2|
#|  3|  8|  2|
#|  4|  1|  1|
#|  5|  3|  3|
#|  5|  3|  3|
#|  5|  3|  3|
#+---+---+---+

However, we are creating a max_n-length array for each row, as opposed to just an n-length array in the udf solution. It's not immediately clear to me how this will scale vs. the udf for large max_n, but I suspect the udf will win out.
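
If you want to check that on your own data, one rough way to compare (a sketch; it reuses df and max_n from above, rebuilds the udf from the earlier answer, and the timings will of course depend on data size and cluster) is to force a full evaluation of each variant and time it:

import time
import pyspark.sql.functions as f
from pyspark.sql.types import ArrayType, IntegerType

# variant 1: the udf + explode approach from the answer above
n_to_array = f.udf(lambda n: [n] * n, ArrayType(IntegerType()))
udf_variant = df.withColumn('n', f.explode(n_to_array(df.n)))

# variant 2: the max_n array + filter approach from this answer
array_variant = df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
    .select('A', 'B', 'n', f.explode('n_array').alias('col'))\
    .where(f.col('col') < f.col('n'))\
    .drop('col')

for label, frame in [('udf + explode', udf_variant), ('array + filter', array_variant)]:
    start = time.time()
    frame.count()  # count() forces full evaluation without collecting rows to the driver
    print(label, time.time() - start)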

pault