
I am new to Spark, and I want to transform the source dataframe below (loaded from a JSON file):

+--+-----+-----+
|A |count|major|
+--+-----+-----+
| a|    1|   m1|
| a|    1|   m2|
| a|    2|   m3|
| a|    3|   m4|
| b|    4|   m1|
| b|    1|   m2|
| b|    2|   m3|
| c|    3|   m1|
| c|    4|   m3|
| c|    5|   m4|
| d|    6|   m1|
| d|    1|   m2|
| d|    2|   m3|
| d|    3|   m4|
| d|    4|   m5|
| e|    4|   m1|
| e|    5|   m2|
| e|    1|   m3|
| e|    1|   m4|
| e|    1|   m5|
+--+-----+-----+

into the result dataframe below:

+--+--+--+--+--+--+
|A |m1|m2|m3|m4|m5|
+--+--+--+--+--+--+
| a| 1| 1| 2| 3| 0|
| b| 4| 1| 2| 0| 0|
| c| 3| 0| 4| 5| 0|
| d| 6| 1| 2| 3| 4|
| e| 4| 5| 1| 1| 1|
+--+--+--+--+--+--+

Here is the transformation rule:

  1. The result dataframe consists of A plus n major columns, where the major column names are given by:

    sorted(src_df.rdd.map(lambda x: x[2]).distinct().collect())
    
  2. The result dataframe contains m rows, where the values of the A column are given by:

    sorted(src_df.rdd.map(lambda x: x[0]).distinct().collect())
    
  3. The value of each major column in the result dataframe is the count from the source dataframe for the corresponding A and major (e.g. the count in row 1 of the source dataframe maps to the cell where A is a and the column is m1). Combinations that do not occur in the source are shown as 0, as in the result above.

  4. Each combination of A and major occurs at most once in the source dataframe (think of the two columns as a composite primary key in SQL).
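
As a sanity check on the four rules, the expected result can be computed in plain Python, without Spark. This is only a sketch under assumptions: `rows` is a hypothetical in-memory copy of part of the source data as (A, count, major) tuples, and `pivot_rows` is a name made up for this illustration.

```python
def pivot_rows(rows):
    """Pivot (A, count, major) tuples into one row per A, one column per major."""
    majors = sorted({major for _, _, major in rows})  # rule 1: column names
    keys = sorted({a for a, _, _ in rows})            # rule 2: row keys
    # Rule 4: each (A, major) pair occurs at most once, so a dict is enough.
    cells = {(a, major): count for a, count, major in rows}
    header = ["A"] + majors
    # Rule 3: look up each cell, using 0 where the pair is absent.
    body = [[a] + [cells.get((a, m), 0) for m in majors] for a in keys]
    return header, body


# Hypothetical sample: the rows for A = a and A = b from the source table.
rows = [("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"), ("a", 3, "m4"),
        ("b", 4, "m1"), ("b", 1, "m2"), ("b", 2, "m3")]

header, body = pivot_rows(rows)
print(header)
for line in body:
    print(line)
```

Any Spark-based answer should agree with this reference on the same input, apart from row order.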

resec

4 Answers


Using zero323's dataframe,

df = sqlContext.createDataFrame([
("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
("e", 1, "m4"), ("e", 1, "m5")], 
("a", "cnt", "major"))

you could also use

reshaped_df = df.groupby('a').pivot('major').max('cnt').fillna(0)
TrentWoodbury

Let's start with example data:

df = sqlContext.createDataFrame([
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5")], 
    ("a", "cnt", "major"))

Please note that I've changed count to cnt. Count is a reserved keyword in most SQL dialects, so it is not a good choice for a column name.

There are at least two ways to reshape this data:

  • aggregating over DataFrame

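    # Alias max so it does not shadow Python's builtin max.
    from pyspark.sql.functions import col, when, max as max_
    
    # Distinct major values, sorted, become the output column names.
    # (.rdd is needed on Spark 2.0+, where DataFrame no longer has map.)
    majors = sorted(df.select("major")
        .distinct()
        .rdd
        .map(lambda row: row[0])
        .collect())
    
    # One column per major: cnt where the row's major matches, null otherwise.
    cols = [when(col("major") == m, col("cnt")).otherwise(None).alias(m) 
        for m in majors]
    # max skips nulls, so it picks the single non-null cnt in each group.
    maxs = [max_(col(m)).alias(m) for m in majors]
    
    reshaped1 = (df
        .select(col("a"), *cols)
        .groupBy("a")
        .agg(*maxs)
        .na.fill(0))  # majors missing for a key become 0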
    
    reshaped1.show()
    
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## |  e|  4|  5|  1|  1|  1|
    ## +---+---+---+---+---+---+
    
  • groupBy over RDD

    from pyspark.sql import Row
    
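    grouped = (df
        .rdd  # needed on Spark 2.0+, where DataFrame no longer has map
        .map(lambda row: (row.a, (row.major, row.cnt)))
        .groupByKey())
    
    def make_row(kv):
        key, vs = kv
        # {major: cnt, ..., "a": key} for this group
        tmp = dict(list(vs) + [("a", key)])
        # majors missing for this key default to 0
        return Row(**{name: tmp.get(name, 0) for name in ["a"] + majors})
    
    reshaped2 = sqlContext.createDataFrame(grouped.map(make_row))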
    
    reshaped2.show()
    
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  e|  4|  5|  1|  1|  1|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## +---+---+---+---+---+---+
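
The first approach above (conditional columns followed by max) can look opaque, so here is a plain-Python sketch of the same idea, without Spark. It is only an illustration under assumptions: the sample `rows` and the helper `max_ignoring_none` are made up here, and the helper merely mimics how SQL MAX skips nulls.

```python
def max_ignoring_none(values):
    """Mimic SQL MAX: skip None values; return None if nothing is left."""
    present = [v for v in values if v is not None]
    return max(present) if present else None


# Hypothetical sample in (a, cnt, major) order.
rows = [("a", 1, "m1"), ("a", 2, "m3"), ("c", 3, "m1"), ("c", 5, "m4")]
majors = sorted({major for _, _, major in rows})

# Step 1 (the when(...) columns): spread each row into one slot per major,
# keeping cnt only in the slot of the row's own major.
wide = [(a, [cnt if major == m else None for m in majors])
        for a, cnt, major in rows]

# Step 2 (groupBy): collect the widened rows per key.
grouped = {}
for a, slots in wide:
    grouped.setdefault(a, []).append(slots)

# Step 3 (max per column, then fill): at most one slot per column is set,
# so max simply selects it; columns with no value get 0.
result = {}
for a, slot_rows in grouped.items():
    maxima = [max_ignoring_none(column) for column in zip(*slot_rows)]
    result[a] = [0 if v is None else v for v in maxima]

print(majors)
print(result)
```

Because each (a, major) pair is unique, max never has to choose between two real values here; it is only a way to drop the nulls.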
    
zero323

This is your original dataframe:

df.show()
+--+-----+-----+
|A |count|major|
+--+-----+-----+
| a|    1|   m1|
| a|    1|   m2|
| a|    2|   m3|
| a|    3|   m4|
| b|    4|   m1|
| b|    1|   m2|
| b|    2|   m3|
| c|    3|   m1|
| c|    4|   m3|
| c|    5|   m4|
| d|    6|   m1|
| d|    1|   m2|
| d|    2|   m3|
| d|    3|   m4|
| d|    4|   m5|
| e|    4|   m1|
| e|    5|   m2|
| e|    1|   m3|
| e|    1|   m4|
| e|    1|   m5|
+--+-----+-----+

Use pivot to reshape the data on "major", grouped by "A", with the sum of "count" as the cell value. Combinations that do not occur come back as null, so fill them with 0:

data = ( df.groupBy("A")
    .pivot("major")
    .sum("count")
    .na.fill(0) )

display(data)
+--+--+--+--+--+--+
|A |m1|m2|m3|m4|m5|
+--+--+--+--+--+--+
| a| 1| 1| 2| 3| 0|
| b| 4| 1| 2| 0| 0|
| c| 3| 0| 4| 5| 0|
| d| 6| 1| 2| 3| 4|
| e| 4| 5| 1| 1| 1|
+--+--+--+--+--+--+

@TrentWoodbury's answer is a good one and I voted for it. However, it doesn't work if the aggregated value is not a number, because .max('cnt') only accepts numeric columns. Also, max is potentially slower, or simply unnecessary, if you know there are no duplicate values.

The following works for all data types:

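from pyspark.sql.functions import first

# Parentheses instead of backslash continuations: a backslash followed by
# a trailing space is a syntax error.
reshaped_df = (df.groupby('a')
  .pivot('major')
  .agg(first('cnt')))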
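
To see why the choice of aggregate only matters when an (a, major) pair repeats, here is a plain-Python sketch without Spark. The helpers `pivot_with` and `first_value` and the sample data are made up for illustration and do not reproduce Spark's implementation. In particular, `first_value` takes the first item in input order, whereas the row Spark's `first` picks among duplicates depends on row order, which is not guaranteed after a shuffle.

```python
from collections import defaultdict


def pivot_with(rows, agg):
    """Collect cnt values per (a, major) cell and reduce each cell with agg."""
    cells = defaultdict(list)
    for a, cnt, major in rows:
        cells[(a, major)].append(cnt)
    return {key: agg(values) for key, values in cells.items()}


def first_value(values):
    """Stand-in for first: take whatever value comes first."""
    return values[0]


# Unique (a, major) pairs: every aggregate returns the same cells.
unique = [("a", 1, "m1"), ("a", 2, "m3"), ("b", 4, "m1")]
by_max = pivot_with(unique, max)
by_sum = pivot_with(unique, sum)
by_first = pivot_with(unique, first_value)
print(by_max == by_sum == by_first)

# A duplicated pair: now the aggregates disagree.
duplicated = unique + [("a", 5, "m1")]
dup_max = pivot_with(duplicated, max)[("a", "m1")]
dup_sum = pivot_with(duplicated, sum)[("a", "m1")]
dup_first = pivot_with(duplicated, first_value)[("a", "m1")]
print(dup_max, dup_sum, dup_first)

# Non-numeric cell values pass through first_value unchanged.
labels = pivot_with([("a", "x", "m1"), ("b", "y", "m1")], first_value)
print(labels)
```

With the uniqueness guarantee from the question, any of the three gives the same table, so the cheapest one that supports the column type is a reasonable choice.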
Joshcodes