I did some further investigation and stumbled upon this post about adding metadata to a column in PySpark. Based on that, I was able to create a function called group_low_freq that I think is quite efficient: it runs the StringIndexer only once per column, and then modifies the indexed column and its metadata to bin all elements that occur in less than x% of the rows into a separate group called "other". Because we also modify the metadata, we can still retrieve the original strings later with IndexToString. The function and an example are given below.
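For context, the whole trick hinges on the ml_attr metadata that StringIndexer attaches to its output column. A minimal sketch of how to inspect it (assuming df is a DataFrame with a string column x1; the printed dict illustrates the expected shape, it is not verified output):

from pyspark.ml.feature import StringIndexer

# Index a string column and look at the metadata on the result.
indexed = StringIndexer(inputCol='x1', outputCol='ix_x1').fit(df).transform(df)
print(indexed.schema['ix_x1'].metadata)
# e.g. {'ml_attr': {'vals': ['a', 'b', 'c'], 'type': 'nominal', 'name': 'ix_x1'}}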
Code:
import findspark
findspark.init()
import pyspark as ps
from pyspark.sql import SQLContext, Column
import pandas as pd
from pyspark.sql.functions import col, count as sparkcount, when, lit
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml import Pipeline
import json

try:
    sc
except NameError:
    sc = ps.SparkContext()
sqlContext = SQLContext(sc)

def withMeta(self, alias, meta):
    """Return the column aliased with the given metadata attached."""
    sc = ps.SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

def group_low_freq(df, inColumns, threshold=.01, group_text='other'):
    """
    Index string columns and group all observations that occur in less than
    threshold% of the rows in df per column.

    :param df: A pyspark.sql.dataframe.DataFrame
    :param inColumns: String columns that need to be indexed
    :param threshold: Minimum fraction of rows in which a value must occur to keep its own index
    :param group_text: String to use as replacement for the observations that need to be grouped.
    """
    total = df.count()
    for string_col in inColumns:
        # Apply the string indexer
        pipeline = Pipeline(stages=[StringIndexer(inputCol=string_col, outputCol="ix_" + string_col)])
        df = pipeline.fit(df).transform(df)
        # Count the unique elements that occur in more than threshold% of the rows
        n_to_keep = df.groupby(string_col) \
                      .agg((sparkcount(string_col) / total).alias('perc')) \
                      .filter(col('perc') > threshold).count()
        # StringIndexer orders labels by descending frequency, so the first
        # n_to_keep indices are the frequent ones; everything else gets binned.
        this_meta = df.select('ix_' + string_col).schema.fields[0].metadata
        if n_to_keep != len(this_meta['ml_attr']['vals']):
            # Truncate the label list and use the last slot for the group label
            this_meta['ml_attr']['vals'] = this_meta['ml_attr']['vals'][0:(n_to_keep + 1)]
            this_meta['ml_attr']['vals'][n_to_keep] = group_text
            # Remap every index >= n_to_keep to the single group index
            df = df.withColumn('ix_' + string_col,
                               when(col('ix_' + string_col) >= n_to_keep, lit(n_to_keep))
                               .otherwise(col('ix_' + string_col)))
        # Re-attach the column with the corrected metadata
        df = df.withColumn('ix_' + string_col, withMeta(col('ix_' + string_col), "", this_meta))
    return df
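As a side note: on Spark 2.2+ the JVM round-trip in withMeta can be avoided, because Column.alias accepts a metadata keyword argument. A sketch of the same re-attachment step inside the loop using that API instead (untested here; string_col and this_meta are the loop variables from group_low_freq):

# Alternative to withMeta on Spark 2.2+: Column.alias takes a metadata kwarg.
df = df.withColumn('ix_' + string_col,
                   col('ix_' + string_col).alias('ix_' + string_col, metadata=this_meta))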
# SAMPLE DATA -----------------------------------------------------------------
df = pd.DataFrame({'x1' : ['a','b','a','b','c'], # a: 0.4, b: 0.4, c: 0.2
'x2' : ['a','b','a','b','a'], # a: 0.6, b: 0.4, c: 0.0
'x3' : ['a','a','a','a','a'], # a: 1.0, b: 0.0, c: 0.0
'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)
# TEST THE FUNCTION -----------------------------------------------------------
df = group_low_freq(df, df.columns, 0.25)
ix_cols = [x for x in df.columns if 'ix_' in x]

for string_col in ix_cols:
    idx_to_string = IndexToString(inputCol=string_col, outputCol=string_col[3:] + 'grouped')
    df = idx_to_string.transform(df)

df.show()
Output with a threshold of 25% (so to keep its own index, each value had to occur in more than 25% of the rows):
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| x1| x2| x3| x4|ix_x1|ix_x2|ix_x3|ix_x4|x1grouped|x2grouped|x3grouped|x4grouped|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| a| a| a| a| 0.0| 0.0| 0.0| 0.0| a| a| a| other|
| b| b| a| b| 1.0| 1.0| 0.0| 0.0| b| b| a| other|
| a| a| a| c| 0.0| 0.0| 0.0| 0.0| a| a| a| other|
| b| b| a| d| 1.0| 1.0| 0.0| 0.0| b| b| a| other|
| c| a| a| e| 2.0| 0.0| 0.0| 0.0| other| a| a| other|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
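To double-check that the metadata was rewritten correctly, you can print the label list per indexed column. The comment shows what I would expect given the sample frequencies (my expectation, not verified output); for x4 every value fell below the threshold, so only the group label should remain:

# Inspect the rewritten labels stored in the column metadata.
for c in ix_cols:
    print(c, df.schema[c].metadata['ml_attr']['vals'])
# Expected along the lines of:
# ix_x1 ['a', 'b', 'other']
# ix_x2 ['a', 'b']
# ix_x3 ['a']
# ix_x4 ['other']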