0

I am converting SQL code to PySpark.

The SQL code uses ROLLUP to sum up the count for each state.

I am trying to do the same thing in PySpark, but I don't know how to get the total count row.

I have a table with state, city, and count columns, and I want to add a total count row for each state at the end of that state's section.

This is a sample input:

State   City      Count
WA      Seattle    10
WA      Tacoma     11
MA      Boston     11
MA      Cambridge  3
MA      Quincy     5

This is my desired output:

State   City       Count
WA      Seattle    10
WA      Tacoma     11
WA      Total      21
MA      Boston     11
MA      Cambridge  3
MA      Quincy     5
MA      Total      19

I don't know how to add the total count rows in between states.

I did try rollup; here is my code:

df2 = df.rollup('STATE').count()

and the result shows up like this:

State  Count
WA     21
MA     19

But I want the Total after each state.

  • Rollup is supported in Spark both in SQL and the `DataFrame` API (https://stackoverflow.com/q/37975227/9613318). Did you experience any problems? – Alper t. Turker Apr 11 '18 at 20:03
  • Yes, I did try rollup, but it ended up with only the state and the total count. I want the total count row in between states. – yokielove Apr 11 '18 at 20:19
  • @yokielove it would be helpful if you could share the SQL code that you referenced. – pault Apr 11 '18 at 21:26

2 Answers

2

Since you want the Total as a new row inside your DataFrame, one option is to union the result of a groupBy() aggregation onto the original DataFrame and sort by ["State", "City", "Count"] (to ensure that the "Total" row displays last in each group):

import pyspark.sql.functions as f

# build one "Total" row per state, union it onto the original rows,
# and sort so that each state's "Total" row displays last
df.union(
    df.groupBy("State")
      .agg(f.sum("Count").alias("Count"))
      .select("State", f.lit("Total").alias("City"), "Count")
).sort("State", "City", "Count").show()
#+-----+---------+-----+
#|State|     City|Count|
#+-----+---------+-----+
#|   MA|   Boston|   11|
#|   MA|Cambridge|    3|
#|   MA|   Quincy|    5|
#|   MA|    Total|   19|
#|   WA|  Seattle|   10|
#|   WA|   Tacoma|   11|
#|   WA|    Total|   21|
#+-----+---------+-----+
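
One caveat: the alphabetical sort on City only keeps "Total" last because no city in the sample sorts after the word "Total"; a name like "Tucson" would break it. A minimal sketch of a more robust variant, using an explicit is_total sort key (this column is my own addition, not part of the answer above):

import pyspark.sql.functions as f

# flag original rows with 0 and total rows with 1, sort on the flag,
# then drop it; "Total" now lands last regardless of city names
totals = (
    df.groupBy("State")
      .agg(f.sum("Count").alias("Count"))
      .select("State", f.lit("Total").alias("City"), "Count")
      .withColumn("is_total", f.lit(1))
)
df.withColumn("is_total", f.lit(0)) \
    .union(totals) \
    .sort("State", "is_total", "City") \
    .drop("is_total") \
    .show()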
-2

Either call rollup directly on the DataFrame (it is not chained after a groupBy()):

from pyspark.sql.functions import count

df.rollup("State", "City").agg(count("*"))
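
Note that rollup marks each subtotal row with null (and adds a grand-total row), and the question asks for a sum of Count rather than a row count, so some post-processing is needed to match the desired output. A rough sketch, with column names taken from the sample data:

import pyspark.sql.functions as f

(df.rollup("State", "City")
   .agg(f.sum("Count").alias("Count"))
   .filter(f.col("State").isNotNull())                       # drop the grand-total row
   .withColumn("City", f.coalesce(f.col("City"), f.lit("Total")))  # label the subtotal rows
   .sort("State", "City")
   .show())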

or just register the DataFrame as a temporary view:

df.createOrReplaceTempView("df")

and apply your current SQL query with

spark.sql("...")
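
The original SQL wasn't shared in the question, so purely as an illustration, here is a rough guess at what the ROLLUP query might look like for the sample data (the column names, the COALESCE labeling, and the assumption that State is never null in the real data are all mine):

df.createOrReplaceTempView("df")

spark.sql("""
    SELECT State,
           COALESCE(City, 'Total') AS City,
           SUM(Count) AS Count
    FROM df
    GROUP BY State, City WITH ROLLUP
    HAVING State IS NOT NULL   -- drops the grand-total row
    ORDER BY State, City
""").show()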