I think the udf answer by @Ahmed is the best way to go, but here is an alternative method that may be as good or better for small n:
First, collect the maximum value of n over the whole DataFrame:
import pyspark.sql.functions as f

max_n = df.select(f.max('n').alias('max_n')).first()['max_n']
print(max_n)
#3
Now create an array of length max_n for each row, containing the numbers in range(max_n). The output of this intermediate step is a DataFrame like:
df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)])).show()
#+---+---+---+---------+
#| A| B| n| n_array|
#+---+---+---+---------+
#| 1| 2| 1|[0, 1, 2]|
#| 2| 9| 1|[0, 1, 2]|
#| 3| 8| 2|[0, 1, 2]|
#| 4| 1| 1|[0, 1, 2]|
#| 5| 3| 3|[0, 1, 2]|
#+---+---+---+---------+
Now we explode the n_array column and filter to keep only the values in the array that are less than n. This ensures that we have n copies of each row. Finally, we drop the exploded column to get the end result:
df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
.select('A', 'B', 'n', f.explode('n_array').alias('col'))\
.where(f.col('col') < f.col('n'))\
.drop('col')\
.show()
#+---+---+---+
#| A| B| n|
#+---+---+---+
#| 1| 2| 1|
#| 2| 9| 1|
#| 3| 8| 2|
#| 3| 8| 2|
#| 4| 1| 1|
#| 5| 3| 3|
#| 5| 3| 3|
#| 5| 3| 3|
#+---+---+---+
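To see why the filter leaves exactly n copies of each row, here is a minimal plain-Python model of the explode-and-filter step. It is only a sketch for checking the logic without Spark: the rows list and the replicate helper are made up for illustration and are not part of the PySpark API.

```python
# Plain-Python model of the explode-and-filter step (no Spark required).
# `rows` mirrors the example DataFrame as (A, B, n) tuples; `replicate`
# is a hypothetical helper used only for this illustration.
rows = [(1, 2, 1), (2, 9, 1), (3, 8, 2), (4, 1, 1), (5, 3, 3)]
max_n = max(n for _, _, n in rows)


def replicate(rows, max_n):
    out = []
    for a, b, n in rows:
        # "explode": one candidate row per element of [0, ..., max_n - 1]
        for i in range(max_n):
            # "where col < n": exactly n of the max_n candidates survive
            if i < n:
                out.append((a, b, n))
    return out


result = replicate(rows, max_n)
for row in result:
    print(row)
```

Since the indices 0 through n - 1 are the only ones that pass the filter, each row appears n times, which matches the table above.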
However, we are creating a max_n-length array for each row, as opposed to just an n-length array in the udf solution. It's not immediately clear to me how this will scale vs. the udf for large max_n, but I suspect the udf will win out.
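As a rough way to reason about that trade-off, the sketch below counts how many rows the explode produces before the filter compared with how many survive it. It only counts rows and says nothing about actual Spark runtimes; intermediate_sizes is a hypothetical helper, not a library function.

```python
def intermediate_sizes(ns):
    """Return (rows after explode, rows after filter) for a list of n values."""
    max_n = max(ns)
    exploded = len(ns) * max_n  # every row gets a max_n-length array
    kept = sum(ns)              # only n elements per row pass the filter
    return exploded, kept


# The example DataFrame: n values are 1, 1, 2, 1, 3
exploded, kept = intermediate_sizes([1, 1, 2, 1, 3])
print(exploded, kept)  # 15 8

# A skewed case: 999 rows with n = 1 and a single row with n = 1000
skewed_exploded, skewed_kept = intermediate_sizes([1] * 999 + [1000])
print(skewed_exploded, skewed_kept)  # 1000000 1999
```

When the n values are close to each other the overhead is small, but a single large n inflates the intermediate result for every row, which is where I would expect the udf to pull ahead.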