I have a Spark DataFrame in the format below, where each unique id can have at most 3 rows, distinguished by the rank column.

 id pred    prob      rank
485 9716    0.19205872  1
729 9767    0.19610429  1
729 9716    0.186840048 2
729 9748    0.173447074 3
818 9731    0.255104463 1
818 9748    0.215499913 2
818 9716    0.207307154 3

I want to pivot (cast) this into row-wise data so that each id has just one row, and the pred and prob columns are spread across multiple columns distinguished by the rank value as a column suffix.

id  pred_1  prob_1      pred_2  prob_2     pred_3   prob_3
485 9716    0.19205872              
729 9767    0.19610429  9716    0.186840048 9748    0.173447074
818 9731    0.255104463 9748    0.215499913 9716    0.207307154

I am not able to figure out how to do it in PySpark.

Sample code for input data creation:

# Loading the requisite packages
from pyspark.sql.functions import col, explode, array, struct, expr, sum, lit
# Creating the DataFrame (prob is simplified to integer percentages here)
rows = [(485, 9716, 19, 1), (729, 9767, 19, 1), (729, 9716, 18, 2), (729, 9748, 17, 3),
        (818, 9731, 25, 1), (818, 9748, 21, 2), (818, 9716, 20, 3)]
df = sqlContext.createDataFrame(rows, ('id', 'pred', 'prob', 'rank'))
df.show()
Deb
  • Possible duplicate of [this](https://stackoverflow.com/questions/45035940/how-to-pivot-on-multiple-columns-in-spark-sql) question, so please have a look. – sophocles Oct 21 '21 at 07:47

1 Answer

This is the "pivot on multiple columns" problem. Try:

import pyspark.sql.functions as F

df_pivot = df.groupBy('id').pivot('rank').agg(F.first('pred').alias('pred'), F.first('prob').alias('prob')).orderBy('id')
df_pivot.show(truncate=False)
sophocles
过过招
  • @过过招 I am getting an error: "NameError: name 'F' is not defined" – Deb Oct 21 '21 at 07:51
  • You need to import the functions module with the alias F. Add ```import pyspark.sql.functions as F``` at the top of the code above. – sophocles Oct 21 '21 at 07:52
  • Sorry, the import was omitted. ```import pyspark.sql.types as T``` – 过过招 Oct 21 '21 at 07:58
  • The pivoted field names are joined with ```_```, rank first (for example ```1_pred```), and you can reverse the parts: ```df_col_rename = df_pivot.select([F.col(c).alias('_'.join(x for x in c.split('_')[::-1])) for c in df_pivot.columns])``` – 过过招 Oct 21 '21 at 08:19
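
To make the renaming step from the last comment easier to follow, here is a minimal sketch of the name swap written as plain Python, so it can be checked without a Spark session. The helper names `swap_rank_suffix` and `rename_pivot_columns` are hypothetical (they are not part of PySpark), and the list of column names assumes the pivot above produced `<rank>_<alias>` headers such as `1_pred`.

```python
def swap_rank_suffix(name, sep="_"):
    """Turn a pivot-style name such as '1_pred' into 'pred_1'.

    Names that do not start with a numeric rank (e.g. 'id') are returned unchanged.
    """
    prefix, found, rest = name.partition(sep)
    # Only swap when the part before the first separator is the numeric rank.
    if found and prefix.isdigit():
        return f"{rest}{sep}{prefix}"
    return name


def rename_pivot_columns(df):
    """Return df with every '<rank>_<field>' column renamed to '<field>_<rank>'.

    Assumes df exposes .columns and .toDF(*names), as a PySpark DataFrame does.
    """
    return df.toDF(*[swap_rank_suffix(c) for c in df.columns])


# Column names assumed to come out of the pivot (rank value first).
pivot_columns = ["id", "1_pred", "1_prob", "2_pred", "2_prob", "3_pred", "3_prob"]
renamed = [swap_rank_suffix(c) for c in pivot_columns]
print(renamed)
```

With the answer's `df_pivot`, calling `rename_pivot_columns(df_pivot).show()` should then display the headers in the requested `pred_1, prob_1, ...` form. Checking that the prefix is numeric is a design choice of this sketch: it leaves `id` and any other column whose name contains an underscore for a different reason untouched.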