Utkarsh, what you are trying to do is called stratified sampling in Spark. There are direct methods available for it; keyed sampling (`sampleByKey` on pair RDDs) is possible as well.

Spark SQL also offers `sampleBy`, with these overloads:
```scala
sampleBy[T](col : _root_.scala.Predef.String, fractions : _root_.scala.Predef.Map[T, scala.Double], seed : scala.Long) : DataFrame
sampleBy[T](col : _root_.scala.Predef.String, fractions : java.util.Map[T, java.lang.Double], seed : scala.Long) : DataFrame
sampleBy[T](col : org.apache.spark.sql.Column, fractions : _root_.scala.Predef.Map[T, scala.Double], seed : scala.Long) : DataFrame
sampleBy[T](col : org.apache.spark.sql.Column, fractions : java.util.Map[T, java.lang.Double], seed : scala.Long) : DataFrame
```
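As a minimal sketch of how `sampleBy` is called (it lives on `df.stat`, i.e. `DataFrameStatFunctions`): the DataFrame, the `"key"` column, and the fraction values below are illustrative assumptions, not from your data.

```scala
import org.apache.spark.sql.SparkSession

object StratifiedSampleExample {
  def main(args: Array[String]): Unit = {
    // Local session just for the sketch; use your existing session in practice.
    val spark = SparkSession.builder()
      .appName("StratifiedSampleExample")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical data: a stratum column "key" and a payload column "value".
    val df = Seq(("a", 1), ("a", 2), ("a", 3), ("b", 4), ("b", 5), ("c", 6))
      .toDF("key", "value")

    // Keep roughly 50% of stratum "a", all of "b", and none of "c".
    // Keys absent from the map are treated as fraction 0.0.
    val sampled = df.stat.sampleBy(
      col = "key",
      fractions = Map("a" -> 0.5, "b" -> 1.0, "c" -> 0.0),
      seed = 42L
    )
    sampled.show()

    spark.stop()
  }
}
```

Note that the fractions are approximate targets (sampling is per-row Bernoulli within each stratum), so the exact row counts vary run to run unless the seed is fixed.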
For worked examples, see "Stratified sampling in Spark" and https://sparkbyexamples.com/spark/spark-sampling-with-examples/