
Using Scala, how can I split a DataFrame into multiple DataFrames (an array or some other collection), one per distinct value of a column? For example, I want to split the following DataFrame:

ID  Rate    State
1   24  AL
2   35  MN
3   46  FL
4   34  AL
5   78  MN
6   99  FL

to:

data set 1

ID  Rate    State
1   24  AL  
4   34  AL

data set 2

ID  Rate    State
2   35  MN
5   78  MN

data set 3

ID  Rate    State
3   46  FL
6   99  FL
– user1735076

3 Answers


You can collect the unique state values and simply map over the resulting array:

val states = df.select("State").distinct.collect.flatMap(_.toSeq)
val byStateArray = states.map(state => df.where($"State" <=> state))

or build a map:

val byStateMap = states
    .map(state => (state -> df.where($"State" <=> state)))
    .toMap

The same thing in Python:

from itertools import chain
from pyspark.sql.functions import col, lit

states = chain(*df.select("state").distinct().collect())

# eqNullSafe requires PySpark 2.3 or later.
# In 2.2 and before, col("state") == state
# should give the same outcome, ignoring NULLs.
# If NULLs are important, use
# (lit(state).isNull() & col("state").isNull()) | (col("state") == state)
df_by_state = {state:
  df.where(col("state").eqNullSafe(state)) for state in states}

The obvious problem here is that it requires a full data scan for each level, so it is an expensive operation. If you're only looking for a way to split the output, see also How do I split an RDD into two or more RDDs?

In particular, you can write the Dataset partitioned by the column of interest:

val path: String = ???
df.write.partitionBy("State").parquet(path)

and read back if needed:

// Depends on partition pruning
for { state <- states } yield spark.read.parquet(path).where($"State" === state)

// or explicitly read the partition
for { state <- states } yield spark.read.parquet(s"$path/State=$state")

Depending on the size of the data, the number of levels to split by, and the storage and persistence level of the input, it might be faster or slower than multiple filters.
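
If you stay with the multiple-filter approach, caching the input first usually avoids rescanning the source for every state. A minimal sketch, assuming the same df and State column as above:

import org.apache.spark.storage.StorageLevel
import spark.implicits._

// Cache the input once so each per-state filter reuses it
// instead of rescanning the underlying source.
df.persist(StorageLevel.MEMORY_AND_DISK)

val states = df.select("State").distinct.collect.flatMap(_.toSeq)
val byState = states.map(state => state -> df.where($"State" <=> state)).toMap

// Call df.unpersist() once you are finished with the per-state DataFrames.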

– zero323
  • Maybe kind of a late question, but when I try the Python code in Spark 2.2.0 I always get a "Column is not callable" error. I tried several approaches but still get the same error. Any workarounds for this? – inneb Oct 08 '17 at 11:16
  • You need to import `col` with `from pyspark.sql.functions import col` – Luis Mar 28 '18 at 16:04

It is very simple (on Spark 2.x) if you register the DataFrame as a temporary view.

df1.createOrReplaceTempView("df1")

Now you can run the queries:

val df2 = spark.sql("select * from df1 where state = 'FL'")
val df3 = spark.sql("select * from df1 where state = 'MN'")
val df4 = spark.sql("select * from df1 where state = 'AL'")

Now you have df2, df3, and df4. If you want them as local collections, you can use:

df2.collect()
df3.collect()

or even the map/filter functions (sketched below). Please refer to https://spark.apache.org/docs/latest/sql-programming-guide.html#datasets-and-dataframes

Ash
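
For instance, a minimal sketch of the same three splits using the DataFrame filter API instead of SQL, assuming the df1 DataFrame from above:

import org.apache.spark.sql.functions.col

// Same result as the SQL queries above, expressed with the DataFrame API.
val df2 = df1.filter(col("State") === "FL")
val df3 = df1.filter(col("State") === "MN")
val df4 = df1.filter(col("State") === "AL")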

  • Is there a possibility to loop SQL queries in Spark? Collecting all distinct values first and then replacing the "where state = 'FL'" with "where state = 'i'" or something like this? – inneb Oct 08 '17 at 11:11
  • It will add some overhead, but you can still handle it using Spark DataFrames and Scala code (see the sketch below) – ashK Nov 10 '17 at 13:38
  • I used the same approach to split a DF into 5 sub-DFs for doing left joins, but the resultant DF is a view and not an independent DF on its own, which is messing with the left joins. Can I split into independent DFs? – Sandeep540 Oct 14 '19 at 16:30
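
Regarding the loop question in the comments, a minimal sketch, assuming the df1 temp view from this answer and a Spark 2.x spark session:

// Collect the distinct states once, then build one DataFrame per state in a loop.
val states = spark.sql("select distinct state from df1")
  .collect()
  .map(_.getString(0))

val dfByState = states.map { s =>
  s -> spark.sql(s"select * from df1 where state = '$s'")
}.toMap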

You can use:

val stateDF = df.select("state").distinct()             // distinct states as a DataFrame
val states = stateDF.rdd.map(x => x(0)).collect.toList  // distinct states as a local list

df.createOrReplaceTempView("table1")                    // register df so it can be queried as "table1"

for (state <- states)  // loop over each state
{
    val finalDF = sqlContext.sql("select * from table1 where state = '" + state + "'")
    // use or store finalDF here; it goes out of scope at the end of each iteration
}