
Assume I have a PySpark DataFrame as shown below. Each row records one item that a user bought on a specific date.

+--+-------------+-----------+
|ID|  Item Bought| Date      |
+--+-------------+-----------+
|1 |  Laptop     | 01/01/2018|  
|1 |  Laptop     | 12/01/2017|  
|1 |  Car        | 01/12/2018|  
|2 |  Cake       | 02/01/2018|  
|3 |  TV         | 11/02/2017| 
+--+-------------+-----------+

Now I would like to create a new DataFrame as shown below.

+---+--------+-----+------+----+
|ID | Laptop | Car | Cake | TV |
+---+--------+-----+------+----+
|1  | 2      | 1   | 0    | 0  | 
|2  | 0      | 0   | 1    | 0  |
|3  | 0      | 0   | 0    | 1  |
+---+--------+-----+------+----+

There is one column per item, and for each user the value in each item column is the number of times that user bought the item.


2 Answers


If you have the data in PySpark as a DataFrame like this:

df = sc.parallelize([(1, 'laptop', '01/01/2018'),
                     (1, 'laptop', '12/01/2017'),
                     (1, 'car', '01/12/2018'),
                     (2, 'cake', '02/01/2018'),
                     (3, 'tv', '11/02/2017')]).toDF(['id', 'item bought', 'date'])

Now you can use groupby and pivot to get the result:

df2 = (df.groupby('id')
         .pivot('item bought', ['tv', 'cake', 'laptop', 'car'])
         .count()
         .fillna(0))
df2.show()

Result:

+---+---+----+------+---+
| id| tv|cake|laptop|car|
+---+---+----+------+---+
|  1|  0|   0|     2|  1|
|  3|  1|   0|     0|  0|
|  2|  0|   1|     0|  0|
+---+---+----+------+---+

Note that supplying the distinct values to pivot is optional, but providing them up front speeds things up, because otherwise Spark has to run an extra job first to compute them itself.
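
For example, the same pivot can be written without the explicit value list (a minimal sketch using the df defined above); Spark then scans the column for the distinct values before pivoting:

# Same result, but Spark must first discover the distinct
# values of 'item bought' before it can pivot.
df3 = df.groupby('id').pivot('item bought').count().fillna(0)
df3.show()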


Another solution:

import pyspark.sql.functions as F

df = sc.parallelize([
    (1, 'Laptop', '01/01/2018'), (1, 'Laptop', '12/01/2017'),
    (1, 'Car', '01/12/2018'), (2, 'Cake', '02/01/2018'),
    (3, 'TV', '11/02/2017')]).toDF(['ID', 'Item', 'Date'])

# Collect the distinct items to use as column names
items = sorted(df.select("Item").distinct().rdd
               .map(lambda row: row[0])
               .collect())

# One column per item: the item name where the row matches, null otherwise
cols = [F.when(F.col("Item") == m, F.col("Item")).otherwise(None).alias(m) for m in items]
# F.count skips nulls, so this counts how often each item occurs per ID
counts = [F.count(F.col(m)).alias(m) for m in items]

df_reshaped = df.select(F.col("ID"), *cols)\
                .groupBy("ID")\
                .agg(*counts)
df_reshaped.show()
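
Assuming the sample data above, this should produce the same counts as the pivot approach, with the item columns in alphabetical order (row order may vary):

+---+----+---+------+---+
| ID|Cake|Car|Laptop| TV|
+---+----+---+------+---+
|  1|   0|  1|     2|  0|
|  2|   1|  0|     0|  0|
|  3|   0|  0|     0|  1|
+---+----+---+------+---+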