
I am trying to find the sum of the value columns based on the code columns in a table of the format below, using PySpark. In this example I have provided only 3 code columns and their respective value columns, but in a real scenario there can be up to 100.

Table A

id1  item  code_1  Value_1  code_2  Value_2  code_3  value_3
100  1     A       5        X       10       L       20
100  2     B       5        L       10       A       20

Expected output:

id1  item  sum_A  sum_X  sum_L  sum_B  Total
100  1     25     10     30     5      70
100  2     25     10     30     5      70

Can someone help me find a logic to accomplish this output?
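
For reference, the sample table can be reproduced as a DataFrame like this (a minimal sketch, assuming an active SparkSession named spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample data matching Table A above
data = [
    (100, 1, "A", 5, "X", 10, "L", 20),
    (100, 2, "B", 5, "L", 10, "A", 20),
]
cols = ["id1", "item", "code_1", "value_1", "code_2", "value_2", "code_3", "value_3"]
df = spark.createDataFrame(data, cols)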

Amrutha K

2 Answers


You can use the stack, groupBy and pivot functions for this case: stack unpivots the code/value pairs into rows, groupBy + sum aggregates the values per id1 and code, and the join back to the distinct (id1, item) pairs attaches the per-id1 sums to every item row before pivoting.

Example:

from pyspark.sql.functions import expr, sum, concat, lit, col, first

df.show()

# unpivot the code/value pairs into rows (col0 = code, col1 = value),
# then sum the values per id1 and code
df1 = df.select("id1", 'item', expr("stack(3,code_1,value_1,code_2,value_2,code_3,value_3)")).\
  groupBy("id1", "col0").agg(sum("col1").alias("sum")).\
    withColumn("col0", concat(lit("sum_"), col("col0")))

# attach the per-id1 sums to every (id1, item) row and pivot the codes back to columns
df.select("id1", "item").distinct().\
  join(df1, ['id1']).\
    groupBy("id1", "item").\
      pivot("col0").\
        agg(first("sum")).\
        show()

Output:

#sample data
+---+----+------+-------+------+-------+------+-------+
|id1|item|code_1|value_1|code_2|value_2|code_3|value_3|
+---+----+------+-------+------+-------+------+-------+
|100|   1|     A|      5|     X|     10|     L|     20|
|100|   2|     B|      5|     L|     10|     A|     20|
+---+----+------+-------+------+-------+------+-------+

#output    
+---+----+-----+-----+-----+-----+
|id1|item|sum_A|sum_B|sum_L|sum_X|
+---+----+-----+-----+-----+-----+
|100|   2|   25|    5|   30|   10|
|100|   1|   25|    5|   30|   10|
+---+----+-----+-----+-----+-----+

UPDATE:

Dynamic SQL:

# collect every code_* / value_* column; this relies on them appearing
# in the original interleaved order code_1, value_1, code_2, value_2, ...
req_cols = [c for c in df.columns if c.startswith("code_") or c.startswith("value_")]

sql_expr = "stack(" + str(len(req_cols) // 2) + "," + ','.join(req_cols) + ")"
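# for the 3-pair sample above, sql_expr evaluates to
# "stack(3,code_1,value_1,code_2,value_2,code_3,value_3)"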

df.show()
df1 = df.select("id1", 'item', expr(sql_expr)).\
  groupBy("id1", "col0").agg(sum("col1").alias("sum")).\
    withColumn("col0", concat(lit("sum_"), col("col0")))
df.select("id1", "item").distinct().\
  join(df1, ['id1']).\
    groupBy("id1", "item").\
      pivot("col0").\
        agg(first("sum")).\
        show()
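
If the columns are not guaranteed to appear in that interleaved order, the pairs can instead be built explicitly from the numeric suffix (a sketch, assuming every code_i has a matching value_i column):

# count the code_* columns and build the stack() pairs explicitly by suffix
n = len([c for c in df.columns if c.startswith("code_")])
pairs = ','.join(f"code_{i},value_{i}" for i in range(1, n + 1))
sql_expr = f"stack({n},{pairs})"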
notNull

An alternative approach using a window function.

The first part is mostly the same as @notNull's answer.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

n = 3
melt_cols = ','.join([f'code_{x}, Value_{x}' for x in range(1, n+1)])

# unpivot the code/value pairs, then sum the values per (id1, item, code)
df = (df.select('id1', 'item', 
                F.expr(f'stack({n}, {melt_cols}) AS (code, value)'))
      .groupby('id1', 'item')
      .pivot('code')
      .agg(F.sum('value')))

This results in an aggregate per id1 and item:

+---+----+---+----+---+----+
|id1|item|  A|   B|  L|   X|
+---+----+---+----+---+----+
|100|   2| 20|   5| 10|null|
|100|   1|  5|null| 20|  10|
+---+----+---+----+---+----+

Then, use a window function to sum each code column per id1.

# sum each pivoted code column across all items of the same id1
df = df.select('id1', 'item', 
               *[F.sum(x).over(Window.partitionBy('id1')).alias(f'sum_{x}') 
                 for x in df.columns if x not in ['id1', 'item']])

Final result:

+---+----+-----+-----+-----+-----+
|id1|item|sum_A|sum_B|sum_L|sum_X|
+---+----+-----+-----+-----+-----+
|100|   2|   25|    5|   30|   10|
|100|   1|   25|    5|   30|   10|
+---+----+-----+-----+-----+-----+
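
If the Total column from the question's expected output is also needed, it can be added by summing the pivoted sum_* columns afterwards (a sketch, assuming the column names produced above):

from functools import reduce
import operator

sum_cols = [c for c in df.columns if c.startswith('sum_')]
df = df.withColumn('Total', reduce(operator.add, [F.col(c) for c in sum_cols]))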
Emma