
I have a complicated windowing operation that I need help with in PySpark.

I have some data grouped by src and dest, and I need to do the following operations for each group:
- select only the rows whose socket2 value does not appear in socket1 (for any row in this group)
- after applying that filtering criterion, sum the values in the amounts field

amounts  src  dest  socket1  socket2
10       1    2     A        B
11       1    2     B        C
12       1    2     C        D
510      1    2     C        D
550      1    2     B        C
500      1    2     A        B
80       1    3     A        B
80            1         3          A        B

And I want to aggregate it in the following way:
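For src=1, dest=2 the socket1 values are A, B and C, so only the two rows with socket2 = D survive: 12 + 510 = 522. For src=1, dest=3, 80 is the only record (its socket2 value B does not appear in socket1).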

amounts  src  dest
522      1    2
80       1    3
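As a sanity check on the expected totals, here is a minimal plain-Python sketch (no Spark involved) that applies the rule group by group to the sample rows. The `rows` list and the `expected_totals` function are illustrative names, not part of PySpark.

```python
from collections import defaultdict

# Sample rows from the question: (amounts, src, dest, socket1, socket2)
rows = [
    (10, 1, 2, "A", "B"),
    (11, 1, 2, "B", "C"),
    (12, 1, 2, "C", "D"),
    (510, 1, 2, "C", "D"),
    (550, 1, 2, "B", "C"),
    (500, 1, 2, "A", "B"),
    (80, 1, 3, "A", "B"),
]


def expected_totals(rows):
    """Per (src, dest), sum amounts of rows whose socket2 is absent from that group's socket1 values."""
    # First pass: collect every socket1 value seen in each group
    socket1_values = defaultdict(set)
    for _, src, dest, s1, _ in rows:
        socket1_values[(src, dest)].add(s1)

    # Second pass: keep rows whose socket2 is not in their group's socket1 set, and sum them
    totals = defaultdict(int)
    for amount, src, dest, _, s2 in rows:
        if s2 not in socket1_values[(src, dest)]:
            totals[(src, dest)] += amount
    return dict(totals)


print(expected_totals(rows))  # {(1, 2): 522, (1, 3): 80}
```

A group in which every socket2 value also appears among its socket1 values gets no entry at all, which is also what a filter-then-groupBy approach in Spark returns.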

I borrowed the sample data from here: How to write Pyspark UDAF on multiple columns?

makansij

1 Answer


You can split your dataframe into two dataframes, one with socket1 and the other with socket2, and then use a leftanti join instead of filtering (leftanti is available in Spark >= 2.0; a workaround for Spark 1.6 is given below).
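Conceptually, a leftanti join keeps every left-side row whose join key has no match on the right side and returns only the left-side columns. The sketch below models that in plain Python on the sample data, with the socket2 rows on the left and the (src, dest, socket1) keys on the right; `left_anti_join` is a hypothetical helper for illustration, not a Spark function.

```python
# Sample rows from the question: (amounts, src, dest, socket1, socket2)
rows = [
    (10, 1, 2, "A", "B"),
    (11, 1, 2, "B", "C"),
    (12, 1, 2, "C", "D"),
    (510, 1, 2, "C", "D"),
    (550, 1, 2, "B", "C"),
    (500, 1, 2, "A", "B"),
    (80, 1, 3, "A", "B"),
]


def left_anti_join(left, right_keys, key):
    """Keep left rows whose key(row) is not among right_keys (models a "leftanti" join)."""
    right_keys = set(right_keys)
    return [row for row in left if key(row) not in right_keys]


# Right side: every (src, dest, socket) combination that occurs as socket1
socket1_keys = {(src, dest, s1) for _, src, dest, s1, _ in rows}
# Left side: (amounts, src, dest, socket) rows built from socket2
socket2_rows = [(amount, src, dest, s2) for amount, src, dest, _, s2 in rows]

survivors = left_anti_join(socket2_rows, socket1_keys, key=lambda r: (r[1], r[2], r[3]))
print(survivors)  # [(12, 1, 2, 'D'), (510, 1, 2, 'D'), (80, 1, 3, 'B')]
```

Only the two D rows survive in the src=1, dest=2 group and the single row survives in src=1, dest=3, which is where 12 + 510 = 522 and 80 come from.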

First let's create the dataframe:

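# Assumes the pyspark shell, where `sc` and `spark` (Spark >= 2.0) are predefined;
# on Spark 1.6 use sqlContext.createDataFrame instead
df = spark.createDataFrame(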
    sc.parallelize([
        [10,1,2,"A","B"],
        [11,1,2,"B","C"],
        [12,1,2,"C","D"],
        [510,1,2,"C","D"],
        [550,1,2,"B","C"],
        [500,1,2,"A","B"],
        [80,1,3,"A","B"]
    ]), 
    ["amounts","src","dest","socket1","socket2"]
)

Now split the dataframe and join the two parts:

Spark >= 2.0

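# socket1 side: the (src, dest, socket) values that must be excluded
df1 = df.withColumnRenamed("socket1", "socket").drop("socket2")
# socket2 side: the candidate rows, carrying the amounts to sum
df2 = df.withColumnRenamed("socket2", "socket").drop("socket1")
# keep df2 rows whose (src, dest, socket) has no match in df1
res = df2.join(df1, ["src", "dest", "socket"], "leftanti")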

Spark 1.6

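# rename amounts on the socket1 side so it can be tested for NULL after the join
df1 = df.withColumnRenamed("socket1", "socket").drop("socket2").withColumnRenamed("amounts", "amounts1")
df2 = df.withColumnRenamed("socket2", "socket").drop("socket1")
# left outer join, then keep only socket2 rows with no socket1 match (emulates leftanti)
res = df2.join(df1.alias("df1"), ["src", "dest", "socket"], "left").filter("amounts1 IS NULL").drop("amounts1")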
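Why the Spark 1.6 version works: the left outer join leaves amounts1 null exactly for the socket2 rows that found no socket1 match, and the filter keeps only those. Duplicate keys on the right (here, two socket1 rows per key in the src=1, dest=2 group) multiply only the matched rows, which the filter throws away anyway. The plain-Python model below assumes amounts is never null in the source data (otherwise a matched row could also pass the IS NULL test); `left_outer_join` is a hypothetical helper, and `None` stands in for SQL NULL.

```python
# Sample rows from the question: (amounts, src, dest, socket1, socket2)
rows = [
    (10, 1, 2, "A", "B"),
    (11, 1, 2, "B", "C"),
    (12, 1, 2, "C", "D"),
    (510, 1, 2, "C", "D"),
    (550, 1, 2, "B", "C"),
    (500, 1, 2, "A", "B"),
    (80, 1, 3, "A", "B"),
]

# df1 side: (amounts1, src, dest, socket) from socket1; df2 side: (amounts, src, dest, socket) from socket2
df1 = [(amount, src, dest, s1) for amount, src, dest, s1, _ in rows]
df2 = [(amount, src, dest, s2) for amount, src, dest, _, s2 in rows]


def left_outer_join(left, right):
    """Join on (src, dest, socket); emit (left_row, amounts1), with None when nothing matches."""
    out = []
    for left_row in left:
        matches = [r[0] for r in right if r[1:] == left_row[1:]]
        if matches:
            out.extend((left_row, m) for m in matches)  # one output row per matching right row
        else:
            out.append((left_row, None))  # unmatched: amounts1 is NULL
    return out


joined = left_outer_join(df2, df1)
# Equivalent of .filter("amounts1 IS NULL").drop("amounts1"); assumes amounts is never null
res = [left_row for left_row, amounts1 in joined if amounts1 is None]

print(len(joined))  # 11 (matched rows are duplicated by the repeated socket1 keys)
print(res)  # [(12, 1, 2, 'D'), (510, 1, 2, 'D'), (80, 1, 3, 'B')]
```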

And finally the aggregation:

import pyspark.sql.functions as psf
res.groupBy("src", "dest").agg(
    psf.sum("amounts").alias("amounts")
).show()

    +---+----+-------+
    |src|dest|amounts|
    +---+----+-------+
    |  1|   3|     80|
    |  1|   2|    522|
    +---+----+-------+
MaFF