I have a Spark DataFrame that looks like this:
from pyspark.sql import SQLContext, Row
sqlContext = SQLContext(sc)
from pyspark.sql.types import StringType, IntegerType, StructType, StructField,LongType
from pyspark.sql.functions import sum, mean
rdd = sc.parallelize([('retail', 'food'),
                      ('retail', 'food'),
                      ('retail', 'auto'),
                      ('retail', 'shoes'),
                      ('wholesale', 'healthsupply'),
                      ('wholesale', 'foodsupply'),
                      ('wholesale', 'foodsupply'),
                      ('retail', 'toy'),
                      ('retail', 'toy'),
                      ('wholesale', 'foodsupply')])
schema = StructType([StructField('division', StringType(), True),
StructField('category', StringType(), True)
])
df = sqlContext.createDataFrame(rdd, schema)
I want to generate a table like this: for each division, the division name, the division's total record count, and the top 1 and top 2 categories within that division along with their record counts:
division   division_total  cat_top_1   top1_cnt  cat_top_2     top2_cnt
retail     6               food        2         toy           2
wholesale  4               foodsupply  3         healthsupply  1
I can generate cat_top_1 and cat_top_2 using window functions in Spark, but how do I pivot them into one row per division and also add a division_total column? I could not get it right:
from pyspark.sql import Window
from pyspark.sql.functions import asc, col, desc, row_number

df_by_div = df.groupBy('division', 'category').count() \
    .orderBy(asc('division'), desc('count'))
windowSpec = Window.partitionBy('division').orderBy(col('count').desc())
df_list = df_by_div.withColumn('rn', row_number().over(windowSpec)) \
    .where(col('rn') <= 2) \
    .orderBy('division', desc('count'))