- the equivalent of unstack or pivot in pandas is pivot in pyspark
- the equivalent of stack or melt in pandas is explode in pyspark, applied to a map built from the columns to stack
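For readers coming from pandas, here is a minimal sketch of the pandas side of both mappings. It uses a small hypothetical frame shaped like the sample below (columns date, a, b); the names long_df, wide_df and back_df are illustrative only. pivot spreads the long table into one column per value of a, and melt folds it back.

```python
import pandas as pd

# small long-format frame: one row per (date, a) pair with its value b
long_df = pd.DataFrame({
    'date': pd.to_datetime(['2000-01-01', '2001-01-01', '2002-01-01']),
    'a': [0, 1, 2],
    'b': [10, 11, 12],
})

# "unstack"/pivot: one row per date, one column per distinct value of a
wide_df = long_df.pivot(index='date', columns='a', values='b')

# "stack"/melt: back to one row per (date, a); the empty cells created by
# the pivot come back as NaN and are dropped, like the null filter in Spark
back_df = (wide_df.reset_index()
           .melt(id_vars='date', var_name='a', value_name='b')
           .dropna(subset=['b'])
           .sort_values(['date', 'a'])
           .reset_index(drop=True))
print(wide_df)
print(back_df)
```

Note that b becomes a float column after the round trip, because the pivot introduces NaN cells.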
Let's start with a sample dataframe:
from datetime import datetime

# assumes an active SparkSession named `spark`
df = spark.createDataFrame(
    [[datetime(y, 1, 1), a, b]
     for y, a, b in zip(range(2000, 2010), range(10), range(10, 20))],
    ['date', 'a', 'b'])
df.show()
+-------------------+---+---+
| date| a| b|
+-------------------+---+---+
|2000-01-01 00:00:00| 0| 10|
|2001-01-01 00:00:00| 1| 11|
|2002-01-01 00:00:00| 2| 12|
|2003-01-01 00:00:00| 3| 13|
|2004-01-01 00:00:00| 4| 14|
|2005-01-01 00:00:00| 5| 15|
|2006-01-01 00:00:00| 6| 16|
|2007-01-01 00:00:00| 7| 17|
|2008-01-01 00:00:00| 8| 18|
|2009-01-01 00:00:00| 9| 19|
+-------------------+---+---+
1. Unstack
To unstack the dataframe:
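import pyspark.sql.functions as psf

# one row per date, one column per distinct value of 'a';
# max() just picks the single 'b' that falls in each cell
df_unstack = df.groupBy("date").pivot("a").agg(psf.max("b"))
df_unstack.show()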
+-------------------+----+----+----+----+----+----+----+----+----+----+
| date| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|
+-------------------+----+----+----+----+----+----+----+----+----+----+
|2003-01-01 00:00:00|null|null|null| 13|null|null|null|null|null|null|
|2004-01-01 00:00:00|null|null|null|null| 14|null|null|null|null|null|
|2009-01-01 00:00:00|null|null|null|null|null|null|null|null|null| 19|
|2001-01-01 00:00:00|null| 11|null|null|null|null|null|null|null|null|
|2006-01-01 00:00:00|null|null|null|null|null|null| 16|null|null|null|
|2008-01-01 00:00:00|null|null|null|null|null|null|null|null| 18|null|
|2005-01-01 00:00:00|null|null|null|null|null| 15|null|null|null|null|
|2000-01-01 00:00:00| 10|null|null|null|null|null|null|null|null|null|
|2007-01-01 00:00:00|null|null|null|null|null|null|null| 17|null|null|
|2002-01-01 00:00:00|null|null| 12|null|null|null|null|null|null|null|
+-------------------+----+----+----+----+----+----+----+----+----+----+
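Conceptually, groupBy("date").pivot("a").agg(psf.max("b")) turns the distinct values of a into new columns and fills each (date, a) cell with the aggregate of the matching b values, leaving null where nothing matches. A plain-Python sketch of that logic makes the source of all those nulls explicit; the helper pivot_max is hypothetical (not part of any library) and dates are simplified to years.

```python
from collections import defaultdict

def pivot_max(rows, index, column, value):
    """Pivot a list of dict rows: one output row per `index` value,
    one output key per distinct `column` value, cell = max of `value`."""
    pivot_values = sorted({r[column] for r in rows})  # distinct values become columns
    cells = defaultdict(dict)
    for r in rows:
        key = r[column]
        current = cells[r[index]].get(key)
        cells[r[index]][key] = r[value] if current is None else max(current, r[value])
    # missing (index, column) combinations become None, i.e. Spark's null
    return [{index: i, **{str(v): cells[i].get(v) for v in pivot_values}}
            for i in sorted(cells)]

rows = [{'date': 2000 + a, 'a': a, 'b': 10 + a} for a in range(10)]
wide = pivot_max(rows, 'date', 'a', 'b')
print(wide[0])
```

Since each date has exactly one value of a in the sample, each output row has one filled cell and nine None cells, which matches the Spark output above.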
2. Stack
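from itertools import chain

# every pivoted column except the id column
value_cols = [c for c in df_unstack.columns if c != "date"]
# pack the columns into one map {'0': col '0', '1': col '1', ...}, explode it
# into one (key, value) row per entry, then drop the null cells from the pivot
df_stack = df_unstack.select(
    'date',
    psf.explode(psf.create_map(
        list(chain(*[(psf.lit(c), psf.col(c)) for c in value_cols]))
    )).alias("a", "b")) \
    .filter(~psf.isnull("b"))
# note: the keys come from the column names, so 'a' is a string column here
df_stack.show()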
+-------------------+---+---+
| date| a| b|
+-------------------+---+---+
|2003-01-01 00:00:00| 3| 13|
|2004-01-01 00:00:00| 4| 14|
|2009-01-01 00:00:00| 9| 19|
|2001-01-01 00:00:00| 1| 11|
|2006-01-01 00:00:00| 6| 16|
|2008-01-01 00:00:00| 8| 18|
|2005-01-01 00:00:00| 5| 15|
|2000-01-01 00:00:00| 0| 10|
|2007-01-01 00:00:00| 7| 17|
|2002-01-01 00:00:00| 2| 12|
+-------------------+---+---+
If the last snippet seems a bit laborious, this is because explode is meant for ArrayType or MapType columns, not for a list of separate columns. It is closer to the opposite of a .groupBy().agg(psf.collect_set) than of a pivot.
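To see why the map is needed, here is a plain-Python sketch of what create_map plus explode does for each row: the separate columns are first packed into a single map, explode then emits one (key, value) row per map entry, and the filter drops the null cells. The second helper shows the grouping that explode undoes, in the spirit of groupBy().agg(collect_list) (collect_set additionally removes duplicates). The helper names stack_row and collect are hypothetical, and dates are simplified to years.

```python
def stack_row(row, id_col):
    """create_map + explode for one row: one (id, key, value) tuple per
    non-id column, with the null filter from the Spark snippet applied."""
    packed = {k: v for k, v in row.items() if k != id_col}   # create_map
    return [(row[id_col], k, v) for k, v in packed.items()    # explode
            if v is not None]                                  # filter(~isnull)

def collect(pairs):
    """Group (id, value) pairs into id -> list of values (like collect_list)."""
    grouped = {}
    for key, value in pairs:
        grouped.setdefault(key, []).append(value)
    return grouped

wide_row = {'date': 2003, '0': None, '3': 13, '9': None}
print(stack_row(wide_row, 'date'))  # [(2003, '3', 13)]

arrays = collect([(2000, 1), (2000, 2), (2001, 3)])
# explode on an array column: one row per element, undoing the grouping
exploded = [(k, v) for k, vs in arrays.items() for v in vs]
print(arrays, exploded)
```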
3. Cartesian product, self join, or cross join
What you are looking for is not a way to stack or unstack (expand, explode...) a dataframe, but a cartesian product of the dataframe with itself. In pyspark, you can use crossJoin (available since Spark 2.1; on older versions, use join without any join key).
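In plain-Python terms, a cross join pairs every row with every row, so n rows give n × n pairs; the when(...) in the snippet below then keeps b only on the pairs where both sides are the same original row. A minimal sketch with itertools.product, assuming rows are plain (date, a, b) tuples with dates simplified to years:

```python
from itertools import product

rows = [(2000 + i, i, 10 + i) for i in range(10)]  # (date, a, b), as in the sample

# every (left, right) combination: 10 x 10 = 100 pairs
crossed = [
    (left[0], right[1], left[2] if left[0] == right[0] else None)  # (date, a, b)
    for left, right in product(rows, rows)
]
crossed.sort(key=lambda r: (r[0], r[1]))
print(len(crossed), crossed[:3])
```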
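# alias the same DataFrame twice so the columns of each side can be told apart
left = df.alias('left')
right = df.alias('right')
df_cross = left.crossJoin(right) \
    .select(
        psf.col('left.date'),
        psf.col('right.a'),
        # keep b only where the dates match (same original row), null otherwise
        psf.when(psf.col('left.date') == psf.col('right.date'), psf.col('left.b')).alias('b'))
df_cross.sort('date', 'a').show()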
+-------------------+---+----+
| date| a| b|
+-------------------+---+----+
|2000-01-01 00:00:00| 0| 10|
|2000-01-01 00:00:00| 1|null|
|2000-01-01 00:00:00| 2|null|
|2000-01-01 00:00:00| 3|null|
|2000-01-01 00:00:00| 4|null|
|2000-01-01 00:00:00| 5|null|
|2000-01-01 00:00:00| 6|null|
|2000-01-01 00:00:00| 7|null|
|2000-01-01 00:00:00| 8|null|
|2000-01-01 00:00:00| 9|null|
|2001-01-01 00:00:00| 0|null|
|2001-01-01 00:00:00| 1| 11|
|2001-01-01 00:00:00| 2|null|
|2001-01-01 00:00:00| 3|null|
|2001-01-01 00:00:00| 4|null|
|2001-01-01 00:00:00| 5|null|
|2001-01-01 00:00:00| 6|null|
|2001-01-01 00:00:00| 7|null|
|2001-01-01 00:00:00| 8|null|
|2001-01-01 00:00:00| 9|null|
+-------------------+---+----+
In pandas, the equivalent is a merge of the data frame with itself on a constant key:
import numpy as np

df_pd = df.toPandas()
df_pd['key'] = 1  # constant key, so every row matches every row
df_pd_cross = df_pd.merge(df_pd, on='key')  # overlapping columns get _x / _y suffixes
df_pd_cross = df_pd_cross \
    .assign(b=np.where(df_pd_cross.date_x == df_pd_cross.date_y, df_pd_cross.b_x, None)) \
    .rename(columns={'date_x': 'date', 'a_y': 'a'})[['date', 'a', 'b']]
df_pd_cross.sort_values(['date', 'a']).head(20)
+-----+-------------+----+------+
| | date | a | b |
+-----+-------------+----+------+
| 0 | 2000-01-01 | 0 | 10 |
| 1 | 2000-01-01 | 1 | None |
| 2 | 2000-01-01 | 2 | None |
| 3 | 2000-01-01 | 3 | None |
| 4 | 2000-01-01 | 4 | None |
| 5 | 2000-01-01 | 5 | None |
| 6 | 2000-01-01 | 6 | None |
| 7 | 2000-01-01 | 7 | None |
| 8 | 2000-01-01 | 8 | None |
| 9 | 2000-01-01 | 9 | None |
| 10 | 2001-01-01 | 0 | None |
| 11 | 2001-01-01 | 1 | 11 |
| 12 | 2001-01-01 | 2 | None |
| 13 | 2001-01-01 | 3 | None |
| 14 | 2001-01-01 | 4 | None |
| 15 | 2001-01-01 | 5 | None |
| 16 | 2001-01-01 | 6 | None |
| 17 | 2001-01-01 | 7 | None |
| 18 | 2001-01-01 | 8 | None |
| 19 | 2001-01-01 | 9 | None |
+-----+-------------+----+------+
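With pandas 1.2 or later, merge also accepts how='cross', which removes the need for the helper key column. A minimal sketch, built on a small hypothetical frame with the same columns rather than df.toPandas(), so that it runs without Spark:

```python
import numpy as np
import pandas as pd

df_pd = pd.DataFrame({
    'date': pd.to_datetime([f'{y}-01-01' for y in range(2000, 2010)]),
    'a': range(10),
    'b': range(10, 20),
})

# how='cross' pairs every row with every row; overlapping names get _x / _y suffixes
df_pd_cross = df_pd.merge(df_pd, how='cross')
df_pd_cross = (df_pd_cross
               .assign(b=np.where(df_pd_cross.date_x == df_pd_cross.date_y,
                                  df_pd_cross.b_x, None))
               .rename(columns={'date_x': 'date', 'a_y': 'a'})[['date', 'a', 'b']]
               .sort_values(['date', 'a'])
               .reset_index(drop=True))
print(df_pd_cross.head(20))
```

The result has the same 100 rows as the key-based merge, with b filled only on the 10 pairs where both sides share the same date.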