I have a dataframe like the one below:

+-----+-----+-----+
|  123|  124|  125|
+-----+-----+-----+
|    1|    2|    3|
|    9|    9|    4|
|    4|   12|    1|
|    2|    4|    8|
|    7|    6|    3|
|   19|   11|    2|
|   21|   10|   10|
+-----+-----+-----+

I need the data to be in this form:

1:[123,125]
2:[123,124,125]
3:[125]

The output does not need to be sorted. I am new to DataFrames in PySpark, so any help would be appreciated.

Anonymous
    Hi @Sankar and Welcome to SO. What's the logic behind going from the given input DF to the expected output? The question isn't very clear. Please see [How to make good reproducible Apache Spark examples](https://stackoverflow.com/q/48427185/1386551). – blackbishop Feb 26 '20 at 15:47

1 Answer

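Older PySpark versions have no melt API that accomplishes this directly (`DataFrame.unpivot` and its alias `melt` arrived in Spark 3.4), and `pivot` reshapes in the other direction. Instead, `flatMap` over the RDD into a new dataframe and aggregate: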

df.show()

+---+---+---+
|123|124|125|
+---+---+---+
|  1|  2|  3|
|  9|  9|  4|
|  4| 12|  1|
|  2|  4|  8|
|  7|  6|  3|
| 19| 11|  2|
| 21| 10| 10|
+---+---+---+

For each column of each row in the RDD, output a row with two columns: the cell value and the name of the column it came from:

cols = df.columns
(df.rdd
 .flatMap(lambda row: [(row[c], c) for c in cols]).toDF(["value", "column_name"])
 .show())

+-----+-----------+
|value|column_name|
+-----+-----------+
|    1|        123|
|    2|        124|
|    3|        125|
|    9|        123|
|    9|        124|
|    4|        125|
|    4|        123|
|   12|        124|
|    1|        125|
|    2|        123|
|    4|        124|
|    8|        125|
|    7|        123|
|    6|        124|
|    3|        125|
|   19|        123|
|   11|        124|
|    2|        125|
|   21|        123|
|   10|        124|
+-----+-----------+
only showing top 20 rows
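
If staying in the DataFrame API is preferable to the RDD round trip, the same long-format table can likely be produced with Spark SQL's `stack` generator. This is a sketch under assumptions, not part of the original answer: `build_stack_expr` and `melt_with_stack` are hypothetical helper names, and all columns are assumed to share one data type (as the integer columns here do).

```python
def build_stack_expr(cols):
    """Build a Spark SQL stack() expression that unpivots `cols`.

    Each column contributes a (name literal, column reference) pair.
    Backticks are used because names like 123 are not plain identifiers.
    """
    pairs = ", ".join(f"'{c}', `{c}`" for c in cols)
    return f"stack({len(cols)}, {pairs}) as (column_name, value)"


def melt_with_stack(df):
    """Return a (column_name, value) DataFrame with one row per input cell."""
    # Assumption: every input column has the same data type,
    # since all values end up in the single `value` column.
    return df.selectExpr(build_stack_expr(df.columns))


print(build_stack_expr(["123", "124", "125"]))
# stack(3, '123', `123`, '124', `124`, '125', `125`) as (column_name, value)
```

The result of `melt_with_stack(df)` can then be grouped by `value` in the same way as the RDD-based version.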

Then, group by the value and aggregate the column names into a list:

from pyspark.sql import functions as f

(df.rdd
 .flatMap(lambda row: [(row[c], c) for c in cols]).toDF(["value", "column_name"])
 .groupby("value").agg(f.collect_list("column_name"))
 .show())

+-----+-------------------------+
|value|collect_list(column_name)|
+-----+-------------------------+
|   19|                    [123]|
|    7|                    [123]|
|    6|                    [124]|
|    9|               [123, 124]|
|    1|               [123, 125]|
|   10|               [124, 125]|
|    3|               [125, 125]|
|   12|                    [124]|
|    8|                    [125]|
|   11|                    [124]|
|    2|          [124, 123, 125]|
|    4|          [125, 123, 124]|
|   21|                    [123]|
+-----+-------------------------+
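
Note that `collect_list` keeps duplicates (value 3 yields `[125, 125]`), while the question asks for `3:[125]`. Swapping in `f.collect_set("column_name")` should remove the repeats. The plain-Python sketch below mirrors the flatMap and group steps on the sample data, with sets standing in for `collect_set`, so the expected mapping can be checked without a Spark session. `value_to_columns` is a hypothetical helper name, not part of any library.

```python
from collections import defaultdict

cols = ["123", "124", "125"]
rows = [
    (1, 2, 3),
    (9, 9, 4),
    (4, 12, 1),
    (2, 4, 8),
    (7, 6, 3),
    (19, 11, 2),
    (21, 10, 10),
]


def value_to_columns(rows, cols):
    """Map each cell value to the distinct column names it appears in."""
    mapping = defaultdict(set)
    for row in rows:
        # Same pairing as the flatMap step: one (value, column_name) per cell.
        for value, column_name in zip(row, cols):
            mapping[value].add(column_name)
    # Sort the names only to make the printed result deterministic.
    return {value: sorted(names) for value, names in mapping.items()}


result = value_to_columns(rows, cols)
print(result[1], result[2], result[3])
# ['123', '125'] ['123', '124', '125'] ['125']
```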
Dave