I have added a few extra rows to your sample data so that the aggregation is visible. I have used a Scala parallel collection: for each country it looks up the states, uses those values to filter the given DataFrame, then does the aggregation, and at the end it unions all the results back together.
scala> val df = Seq( // each tuple: (industry_id, industry_name, country, state, revenue) -- see toDF below
| ("Indus_1","Indus_1_Name","Country1", "State1",12789979),
| ("Indus_2","Indus_2_Name","Country1", "State2",21789933),
| ("Indus_2","Indus_2_Name","Country1", "State2",31789933), // same industry/country/state as the previous row
| ("Indus_3","Indus_3_Name","Country1", "State3",21789978),
| ("Indus_4","Indus_4_Name","Country2", "State1",41789978),
| ("Indus_4","Indus_4_Name","Country2", "State2",41789978),
| ("Indus_4","Indus_4_Name","Country2", "State2",81789978), // same industry/country/state as the previous row
| ("Indus_4","Indus_4_Name","Country2", "State3",41789978),
| ("Indus_4","Indus_4_Name","Country2", "State3",51789978), // same industry/country/state as the previous row
| ("Indus_5","Indus_5_Name","Country3", "State3",27789978),
| ("Indus_6","Indus_6_Name","Country1", "State1",27899790),
| ("Indus_7","Indus_7_Name","Country3", "State1",27899790),
| ("Indus_8","Indus_8_Name","Country1", "State2",27899790),
| ("Indus_9","Indus_9_Name","Country4", "State1",27899790)
| ).toDF("industry_id","industry_name","country","state","revenue")
df: org.apache.spark.sql.DataFrame = [industry_id: string, industry_name: string ... 3 more fields]
scala> val countryList = Seq("Country1", "Country2", "Country4", "Country5") // countries to report on; not all of them have a stateMap entry
countryList: Seq[String] = List(Country1, Country2, Country4, Country5)
scala> val stateMap = Map( // country -> the pair of states to keep for that country
| "Country1" -> ("State1", "State2"),
| "Country2" -> ("State2", "State3"),
| "Country3" -> ("State31", "State32")
| )
stateMap: scala.collection.immutable.Map[String,(String, String)] = Map(Country1 -> (State1,State2), Country2 -> (State2,State3), Country3 -> (State31,State32))
scala>
scala> :paste
// Entering paste mode (ctrl-D to finish)
// For every country that has an entry in stateMap, keep only the rows of df for
// that country and its two configured states, sum revenue per
// (country, state, industry_name), and union the per-country results.
// NOTE(review): DataFrame transformations are lazy, so .par presumably only
// parallelises building the query plans, not running them -- confirm.
countryList
  // A single Map lookup replaces the exists-scan followed by apply; countries
  // without a stateMap entry are dropped here.
  .flatMap(country => stateMap.get(country).map(states => (country, states)))
  .par
  .map { case (country, (firstState, secondState)) =>
    df.filter($"country" === country && ($"state" === firstState || $"state" === secondState))
      .groupBy("country", "state", "industry_name")
      .agg(sum("revenue").as("total_revenue"))
  }
  // reduce throws on an empty collection (no country matched stateMap);
  // reduceOption yields None in that case and nothing is shown.
  .reduceOption(_ union _)
  .foreach(_.show(false))
// Exiting paste mode, now interpreting.
+--------+------+-------------+-------------+
|country |state |industry_name|total_revenue|
+--------+------+-------------+-------------+
|Country1|State2|Indus_8_Name |27899790 |
|Country1|State1|Indus_6_Name |27899790 |
|Country1|State2|Indus_2_Name |53579866 |
|Country1|State1|Indus_1_Name |12789979 |
|Country2|State3|Indus_4_Name |93579956 |
|Country2|State2|Indus_4_Name |123579956 |
+--------+------+-------------+-------------+
scala>
Edit 1: Moved the aggregation code into a separate function.
scala> def processDF(data: (String, (String, String)), adf: DataFrame): DataFrame = {
| // data is (country, (state, state)): keep the rows of adf for that country whose
| // state equals either of the two states, then sum revenue per
| // (country, state, industry_name) into a total_revenue column.
| val (country, (firstState, secondState)) = data
| val inCountry = $"country" === country
| val inEitherState = $"state" === firstState || $"state" === secondState
| adf
| .filter(inCountry && inEitherState)
| .groupBy("country", "state", "industry_name")
| .agg(sum("revenue").as("total_revenue"))
| }
processDF: (data: (String, (String, String)), adf: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame
scala> :paste
// Entering paste mode (ctrl-D to finish)
// Same pipeline as before, with the per-country filter/aggregate step delegated
// to processDF: build one aggregated DataFrame per country that has a stateMap
// entry, then union them and print the result.
countryList
  // A single Map lookup replaces the exists-scan followed by apply; countries
  // without a stateMap entry are dropped here.
  .flatMap(country => stateMap.get(country).map(states => (country, states)))
  .par
  .map(countryWithStates => processDF(countryWithStates, df))
  // reduce throws on an empty collection (no country matched stateMap);
  // reduceOption yields None in that case and nothing is shown.
  .reduceOption(_ union _)
  .foreach(_.show(false))
// Exiting paste mode, now interpreting.
+--------+------+-------------+-------------+
|country |state |industry_name|total_revenue|
+--------+------+-------------+-------------+
|Country1|State2|Indus_8_Name |27899790 |
|Country1|State1|Indus_6_Name |27899790 |
|Country1|State2|Indus_2_Name |53579866 |
|Country1|State1|Indus_1_Name |12789979 |
|Country2|State3|Indus_4_Name |93579956 |
|Country2|State2|Indus_4_Name |123579956 |
+--------+------+-------------+-------------+
scala>