
I'm trying to get the majority vote of a few different models for a binary classification problem.

I managed to compile a Spark table from a few different Spark tables using:

```python
LR.createOrReplaceTempView("lr")
RF.createOrReplaceTempView("rf")
DT.createOrReplaceTempView("dt")
GBT.createOrReplaceTempView("gbt")
majority = spark.sql("SELECT lr.label, lr, rf, dt, gbt FROM lr, rf, dt, gbt")
```
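Note that listing several tables in `FROM` without a join condition produces a cross join (Cartesian product): with N rows in each of the four views, `majority` would have N^4 rows rather than N. The plain-Python sketch below (no Spark required) illustrates the difference, assuming each prediction table carries a hypothetical row identifier `id`. In Spark SQL the aligned version would be along the lines of `FROM lr JOIN rf USING (id) JOIN dt USING (id) JOIN gbt USING (id)`, again assuming that hypothetical `id` column exists.

```python
from itertools import product

# Hypothetical per-model predictions as (id, prediction) pairs, one per example.
lr = [(0, 0.0), (1, 1.0), (2, 1.0)]
rf = [(0, 0.0), (1, 1.0), (2, 0.0)]
dt = [(0, 1.0), (1, 1.0), (2, 0.0)]
gbt = [(0, 0.0), (1, 0.0), (2, 1.0)]

# "FROM lr, rf, dt, gbt" with no join condition: every combination of rows.
cross = list(product(lr, rf, dt, gbt))

# Joining on the shared (hypothetical) id keeps exactly one row per example.
lr_d, rf_d, dt_d, gbt_d = (dict(t) for t in (lr, rf, dt, gbt))
joined = [(i, lr_d[i], rf_d[i], dt_d[i], gbt_d[i]) for i in sorted(lr_d)]

print(len(cross))  # 81 == 3 ** 4
print(joined[1])   # (1, 1.0, 1.0, 1.0, 0.0)
```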

The output of majority looks like

+-----+---+---+---+---+
|label| lr| rf| dt|gbt|
+-----+---+---+---+---+
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
|  0.0|0.0|0.0|0.0|0.0|
+-----+---+---+---+---+

I'm trying to create a column that takes the majority vote (mode) from those four columns. I've looked into this post, but couldn't exactly get what I want.

Thanks so much for helping!

timxymo1225
  • So you want the [row-wise mode](https://stackoverflow.com/questions/56446362/mode-of-row-as-a-new-column-in-pyspark-dataframe), as opposed to the column wise? What happens if there is a tie? – pault Aug 08 '19 at 21:52
  • @pault row-wise mode is exactly what I want. Regarding the tie, I think I will add another column so there's an odd number of models. – timxymo1225 Aug 09 '19 at 13:32

2 Answers


If you're looking for how to calculate the row-wise mode in Spark, refer to [Mode of row as a new column in PySpark DataFrame](https://stackoverflow.com/questions/56446362/mode-of-row-as-a-new-column-in-pyspark-dataframe). However, you can get your desired result without computing the mode.

Since this is a binary classification problem, each prediction column can only take the value 1.0 or 0.0, so you can simplify the vote by taking the row-wise mean.

You can use the following rule:

  • If the mean of the values in the row is >= 0.5, then at least half of the classifiers predicted 1, so the predicted label should be 1.
  • If the mean of the values in the row is < 0.5, then a majority of the classifiers predicted 0, so the predicted label should be 0.

I am making the assumption that a tie goes in favor of the positive class label.
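As a quick sanity check of this rule, the plain-Python sketch below (no Spark involved) enumerates every possible combination of binary votes and compares the mean rule with the actual mode computed via `collections.Counter`. With an odd number of models the two always agree; with an even number they differ only on ties, where the mean rule picks 1 as assumed above.

```python
from collections import Counter
from itertools import product


def mean_rule(votes):
    # 1 if at least half of the votes are 1, else 0 (ties go to 1)
    return int(sum(votes) / len(votes) >= 0.5)


def mode_or_none(votes):
    # Most common vote, or None when 0 and 1 are equally common (a tie)
    counts = Counter(votes)
    if counts[0] == counts[1]:
        return None
    return counts.most_common(1)[0][0]


def disagreements(n_models):
    # All vote combinations where the mean rule differs from the mode
    return [
        votes
        for votes in product([0, 1], repeat=n_models)
        if mean_rule(votes) != mode_or_none(votes)
    ]


print(len(disagreements(5)))  # 0: odd number of models, so never a tie
print(disagreements(2))       # [(0, 1), (1, 0)]: only the ties differ
```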

You can implement this as follows:

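```python
# adapted from https://stackoverflow.com/a/32672278
from functools import reduce
from operator import add
from pyspark.sql.functions import col, lit

# number of model columns (everything except the true label)
n = lit(len([c for c in majority.columns if c != "label"]))
# row-wise mean of the predictions: (lr + rf + dt + gbt) / n
rowMean = reduce(add, (col(x) for x in majority.columns if x != "label")) / n

# write the vote to a new column so the true label is kept
majority = majority.withColumn("majority_vote", (rowMean >= 0.5).cast("int"))
```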
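The `reduce(add, ...)` call simply chains `+` across the prediction columns, so the Spark expression amounts to `(lr + rf + dt + gbt) / n`. The sketch below applies the same pattern to plain Python dicts standing in for rows; it is an illustration only, not Spark code, and the row values are made up.

```python
from functools import reduce
from operator import add

# Rows as dicts, standing in for Spark rows (hypothetical values).
rows = [
    {"label": 0.0, "lr": 0.0, "rf": 0.0, "dt": 0.0, "gbt": 0.0},
    {"label": 1.0, "lr": 1.0, "rf": 1.0, "dt": 0.0, "gbt": 1.0},
    {"label": 1.0, "lr": 1.0, "rf": 0.0, "dt": 1.0, "gbt": 0.0},
]

# Prediction columns are everything except the true label.
pred_cols = [c for c in rows[0] if c != "label"]
n = len(pred_cols)

for row in rows:
    # reduce(add, ...) computes ((lr + rf) + dt) + gbt
    row_mean = reduce(add, (row[c] for c in pred_cols)) / n
    row["majority_vote"] = int(row_mean >= 0.5)

print([row["majority_vote"] for row in rows])  # [0, 1, 1]
```

The third row is a 2-2 tie, so its mean is exactly 0.5 and the vote goes to the positive class.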

Alternatively, you can check whether the number of prediction columns that are greater than 0 is at least n/2:

```python
n = lit(len([c for c in majority.columns if c != "label"]))

# count of columns with a positive prediction
positiveCount = reduce(
    add,
    ((col(x) > 0).cast("int") for x in majority.columns if x != "label")
)

# 1 if at least half of the models predicted 1, else 0
majority = majority.withColumn(
    "majority_vote",
    (positiveCount >= (n / 2.0)).cast("int")
)
```
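The two formulations agree because, for 0/1 predictions, the number of positive votes equals n times the row mean, so `count >= n/2` holds exactly when `mean >= 0.5`. A brute-force check in plain Python (no Spark) over every combination of binary predictions, as a sketch:

```python
from itertools import product


def mismatches(n):
    """Vote combinations where the mean rule and the count rule disagree."""
    bad = []
    for votes in product([0.0, 1.0], repeat=n):
        mean_vote = int(sum(votes) / n >= 0.5)
        # mirrors summing (col(x) > 0).cast("int") over the prediction columns
        positive_count = sum(int(v > 0) for v in votes)
        count_vote = int(positive_count >= n / 2.0)
        if mean_vote != count_vote:
            bad.append(votes)
    return bad


print(mismatches(4))  # [] for the four models in the question
```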
pault
  • I do have a table for average probability which uses the same idea, but this also gets the job done, thanks! – timxymo1225 Aug 09 '19 at 13:44
  • Using the average probability does not necessarily yield the same result, because there is no rule that states you must use the same threshold (i.e. `p > 0.5`) for each model. You may have other requirements (maximize KS, minimize False Positive Rate, etc.) that can lead to a different choice of threshold for each model. – pault Aug 09 '19 at 13:58
  • one thing I would edit in the code is ```n = lit(len(c for c in majority.columns if c != "label"))``` since I got an error saying that couldn't find length for that data type. I just did ```n = len(majority.columns) - 1``` instead. – timxymo1225 Aug 09 '19 at 19:44
  • You are right, it should be square brackets in addition to the parentheses - that would make it a list and not a generator. Updated, but your way works too. – pault Aug 09 '19 at 19:48
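To make the threshold point concrete, here is a small plain-Python example with made-up probabilities and made-up per-model thresholds (hypothetical numbers, not from the question): every model votes 1 under its own threshold, yet the averaged probability falls below a single shared 0.5 cutoff.

```python
# Hypothetical predicted probabilities for one example, one per model.
probs = {"lr": 0.45, "rf": 0.40, "dt": 0.60, "gbt": 0.35}
# Hypothetical per-model thresholds, e.g. each tuned for its own metric.
thresholds = {"lr": 0.40, "rf": 0.30, "dt": 0.50, "gbt": 0.30}

# Majority vote over per-model thresholded predictions (ties go to 1).
votes = [int(probs[m] >= thresholds[m]) for m in probs]
vote_label = int(sum(votes) / len(votes) >= 0.5)

# Average probability with one shared 0.5 threshold.
avg_prob = sum(probs.values()) / len(probs)
avg_label = int(avg_prob >= 0.5)

print(votes, vote_label)  # [1, 1, 1, 1] 1
print(avg_label)          # 0, since avg_prob is about 0.45
```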

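Using mainly Spark SQL, the following computes the mode of each column, i.e. the most frequent prediction of each model across all rows: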

```python
df.createOrReplaceTempView("df")
cols_to_mode = ["lr", "rf", "dt", "gbt"]

# one CTE per column, counting how often each distinct value occurs
qry_pt1 = "with " + " ,".join([f""" agg_{c} as (
  select {c} as mode_col, count(*) as cnt from df group by {c})""" for c in cols_to_mode])
# keep the most frequent value of each column and stack the results
qry_pt2 = " union all ".join([f"(select mode_col, '{c}' as col from agg_{c} order by cnt desc limit 1)" for c in cols_to_mode])
df_modes = spark.sql(qry_pt1 + " " + qry_pt2)
```

Scala equivalent:

```scala
val colsToMode = Seq("lr", "rf", "dt", "gbt")
// one CTE per column, counting how often each distinct value occurs
val qryPt1 = "with " + colsToMode.map(c => s"""
  agg_${c} as (
  select ${c} as mode_col, count(*) as cnt from df group by ${c})
  """).mkString(" ,")

// keep the most frequent value of each column and stack the results
val qryPt2 = colsToMode.map(c => s"(select mode_col, '${c}' as col from agg_${c} order by cnt desc limit 1)").mkString(" union all ")
val dfModes = spark.sql(qryPt1 + qryPt2)
```
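Note that this query returns one mode per column (the most frequent prediction of each model across all rows), whereas the question asks for the mode across the four models within each row. A plain-Python sketch with a hypothetical 3-row table shows how the two differ:

```python
from collections import Counter

cols = ["lr", "rf", "dt", "gbt"]
# Hypothetical predictions, one tuple per row, in the order of cols.
rows = [
    (0.0, 0.0, 1.0, 0.0),
    (1.0, 1.0, 0.0, 1.0),
    (1.0, 0.0, 1.0, 1.0),
]

# Column-wise mode (what the SQL above computes): one value per model.
col_modes = {
    c: Counter(row[i] for row in rows).most_common(1)[0][0]
    for i, c in enumerate(cols)
}

# Row-wise mode (what the question asks for): one value per row.
# With an even number of models a row can tie; none of these rows do.
row_modes = [Counter(row).most_common(1)[0][0] for row in rows]

print(col_modes)  # {'lr': 1.0, 'rf': 0.0, 'dt': 1.0, 'gbt': 1.0}
print(row_modes)  # [0.0, 1.0, 1.0]
```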
datapug