
Let's assume the original data looks like this:

Competitor  Region  ProductA  ProductB
Comp1       A       £10       £15
Comp1       B       £11       £16
Comp1       C       £11       £15
Comp2       A       £9        £16
Comp2       B       £12       £14
Comp2       C       £14       £17
Comp3       A       £11       £16
Comp3       B       £10       £15
Comp3       C       £12       £15

(Ref: Python - splitting dataframe into multiple dataframes based on column values and naming them with those values)

I want to get a list of sub-DataFrames based on the values of a column, say Region, like:

df_A:

Competitor  Region  ProductA  ProductB
Comp1       A       £10       £15
Comp2       A       £9        £16
Comp3       A       £11       £16

In pandas I could do:

for region, df_region in df.groupby('Region'):
    print(df_region)
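For reference, here is a minimal pandas sketch of the target result: one sub-DataFrame per Region, collected into a dict keyed by the Region value. It assumes the prices are stored as plain numbers (in £) rather than "£10"-style strings, and the name `sub_dfs` is only illustrative.

```python
import pandas as pd

# Sample data from the question (prices as plain numbers, in £)
df = pd.DataFrame({
    'Competitor': ['Comp1', 'Comp1', 'Comp1', 'Comp2', 'Comp2', 'Comp2', 'Comp3', 'Comp3', 'Comp3'],
    'Region':     ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C'],
    'ProductA':   [10, 11, 11, 9, 12, 14, 11, 10, 12],
    'ProductB':   [15, 16, 15, 16, 14, 17, 16, 15, 15],
})

# One sub-DataFrame per distinct Region value, keyed by that value
sub_dfs = {region: df_region for region, df_region in df.groupby('Region')}

# df_A from the question
df_A = sub_dfs['A']
print(df_A)
```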

Can I do the same iteration if df is a PySpark DataFrame?

In PySpark, once I call df.groupBy("Region") I get a GroupedData object. I don't need any aggregation like count, mean, etc.; I just need a list of sub-DataFrames, each containing a single "Region" value. Is that possible?

Florian

2 Answers


The approach below should work for you, under the assumption that the list of unique values in the grouping column is small enough to fit in memory on the driver. Hope this helps!

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data
pdf = pd.DataFrame({'region': ['aa','aa','aa','bb','bb','cc'],
                    'x2': [6,5,4,3,2,1],
                    'x3': [1,2,3,4,5,6]})
df = spark.createDataFrame(pdf)

# Get unique values in the grouping column (collected to the driver)
groups = [x[0] for x in df.select("region").distinct().collect()]

# Create a filtered DataFrame for each group in a list comprehension
groups_list = [df.filter(F.col('region') == x) for x in groups]

# Show the results
for group_df in groups_list:
    group_df.show()

Result:

+------+---+---+
|region| x2| x3|
+------+---+---+
|    cc|  1|  6|
+------+---+---+

+------+---+---+
|region| x2| x3|
+------+---+---+
|    bb|  3|  4|
|    bb|  2|  5|
+------+---+---+

+------+---+---+
|region| x2| x3|
+------+---+---+
|    aa|  6|  1|
|    aa|  5|  2|
|    aa|  4|  3|
+------+---+---+
Florian

I also needed the name of each group, so I stored it as the first element of a list, paired with the filtered DataFrame.

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

# SparkSession replaces the legacy sqlContext entry point
spark = SparkSession.builder.getOrCreate()

valuesA = [('Pirate',1),('Monkey',2),('Ninja',3),('Spaghetti',4),('Pirate',5)]
TableA = spark.createDataFrame(valuesA, ['name','id'])

valuesB = [('Pirate',1),('Rutabaga',2),('Ninja',3),('Darth Vader',4),('Pirate',5)]
TableB = spark.createDataFrame(valuesB, ['name','id'])

TableA.show()
TableB.show()

ta = TableA.alias('ta')
tb = TableB.alias('tb')

# Full outer join on both columns ('full_outer' is equivalent)
df = ta.join(tb, (ta.name == tb.name) & (ta.id == tb.id), how='full')
df.show()

# Get unique values in the grouping column
groups = [x[0] for x in df.select("ta.name").distinct().collect()]

# Pair each group value with its filtered DataFrame
groups_list = [[x, df.filter(F.col('ta.name') == x)] for x in groups]

# Show the results
for x, dfx in groups_list:
    print(x)
    dfx.show()

None
+----+---+----+---+
|name| id|name| id|
+----+---+----+---+
+----+---+----+---+

Spaghetti
+---------+---+----+----+
|     name| id|name|  id|
+---------+---+----+----+
|Spaghetti|  4|null|null|
+---------+---+----+----+

Ninja
+-----+---+-----+---+
| name| id| name| id|
+-----+---+-----+---+
|Ninja|  3|Ninja|  3|
+-----+---+-----+---+

Pirate
+------+---+------+---+
|  name| id|  name| id|
+------+---+------+---+
|Pirate|  1|Pirate|  1|
|Pirate|  5|Pirate|  5|
+------+---+------+---+

Monkey
+------+---+----+----+
|  name| id|name|  id|
+------+---+----+----+
|Monkey|  2|null|null|
+------+---+----+----+
ewittry