I am currently trying to find an efficient way of grouping low-occurrence levels in categorical columns of StringType(). I want to do this based on a percentage threshold, i.e. replace all values that occur in less than z% of the rows. It is also important that we can recover the mapping between the numerical values (after applying StringIndexer) and the original strings.

So basically with a threshold of 25%, this dataframe:

+---+---+---+---+
| x1| x2| x3| x4|
+---+---+---+---+
|  a|  a|  a|  a|
|  b|  b|  a|  b|
|  a|  a|  a|  c|
|  b|  b|  a|  d|
|  c|  a|  a|  e|
+---+---+---+---+

Should become this:

+------+------+------+------+
|x1_new|x2_new|x3_new|x4_new|
+------+------+------+------+
|     a|     a|     a| other|
|     b|     b|     a| other|
|     a|     a|     a| other|
|     b|     b|     a| other|
| other|     a|     a| other|
+------+------+------+------+

where c has been replaced with "other" in column x1, and all values in column x4 have been replaced with "other", because they each occur in less than 25% of the rows.

I was hoping to use a regular StringIndexer and make use of the fact that values are indexed in order of descending frequency. We can then calculate how many values to keep and replace the indices of all others with e.g. -1. The issue with this approach is that it raises errors later within IndexToString, I assume because the metadata is lost.
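
For reference, a minimal sketch of that failing approach (the column names and the cutoff of 2 kept levels are just for illustration), using the sample data defined at the bottom of this question:

from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.sql.functions import col, when, lit

# Index x1; StringIndexer assigns indices in order of descending frequency.
indexed = StringIndexer(inputCol='x1', outputCol='ix_x1').fit(df).transform(df)

# Keep the 2 most frequent levels, map everything else to -1.
capped = indexed.withColumn('ix_x1',
                            when(col('ix_x1') >= 2, lit(-1.0)).otherwise(col('ix_x1')))

# This is where it breaks: when() produces a fresh column without the
# ml_attr metadata, so IndexToString cannot look up the labels
# (and -1 has no label anyway).
IndexToString(inputCol='ix_x1', outputCol='x1_new').transform(capped).show()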

My questions: is there a good way to do this? Are there built-in functions that I might be overlooking? Is there a way to keep the metadata?

Thanks in advance!


Code to create the sample data:

import pandas as pd

df = pd.DataFrame({'x1' : ['a','b','a','b','c'],  # a: 0.4, b: 0.4, c: 0.2
                   'x2' : ['a','b','a','b','a'],  # a: 0.6, b: 0.4, c: 0.0
                   'x3' : ['a','a','a','a','a'],  # a: 1.0, b: 0.0, c: 0.0
                   'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)

Answer:

I did some further investigation and stumbled upon this post about adding metadata to a column in pyspark. Based on this, I was able to create a function called group_low_freq that I think is quite efficient; it applies the StringIndexer only once per column, and then modifies the indexed column and its metadata to bin all elements that occur in less than threshold% of the rows into a separate group called "other". Since we also modify the metadata, we are able to retrieve the original strings later on with IndexToString. The function and an example are given below.
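
To illustrate what the function manipulates: after a StringIndexer, the indexed column carries ml_attr metadata listing the labels in index order. A quick way to inspect it (a sketch on the sample data from the question; the exact dict layout may differ slightly between Spark versions):

indexed = StringIndexer(inputCol='x1', outputCol='ix_x1').fit(df).transform(df)
print(indexed.schema['ix_x1'].metadata)
# e.g. {'ml_attr': {'vals': ['a', 'b', 'c'], 'type': 'nominal', 'name': 'ix_x1'}}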


Code:

import findspark
findspark.init()
import pyspark as ps
from pyspark.sql import SQLContext, Column
import pandas as pd
import numpy as np
from pyspark.sql.functions import col, count as sparkcount, when, lit
from pyspark.sql.types import StringType
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml import Pipeline
import json 

try:
    sc
except NameError:
    sc = ps.SparkContext()
    sqlContext = SQLContext(sc)

def withMeta(self, alias, meta):
    """Return the column aliased as `alias` with the metadata dict `meta` attached."""
    sc = ps.SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

def group_low_freq(df, inColumns, threshold=.01, group_text='other'):
    """
    Index string columns and group all observations that occur in less than threshold% of the rows in df per column.
    :param df: A pyspark.sql.dataframe.DataFrame
    :param inColumns: String columns that need to be indexed
    :param threshold: Fraction of rows below which a level is binned into the group (e.g. 0.25 for 25%)
    :param group_text: String to use as replacement for the observations that need to be grouped.
    """
    total = df.count()
    for string_col in inColumns:
        # Apply string indexer
        pipeline = Pipeline(stages=[StringIndexer(inputCol=string_col, outputCol="ix_"+string_col)])
        df = pipeline.fit(df).transform(df)

        # Calculate the number of unique elements to keep
        n_to_keep = df.groupby(string_col).agg((sparkcount(string_col)/total).alias('perc')).filter(col('perc')>threshold).count()

        # If any elements occur in fewer than (threshold * total) rows,
        # map their indices to n_to_keep and relabel that index as group_text.
        this_meta = df.select('ix_' + string_col).schema.fields[0].metadata
        if n_to_keep != len(this_meta['ml_attr']['vals']):  
            this_meta['ml_attr']['vals'] = this_meta['ml_attr']['vals'][0:(n_to_keep+1)]
            this_meta['ml_attr']['vals'][n_to_keep] = group_text    
            df = df.withColumn('ix_'+string_col,when(col('ix_'+string_col)>=n_to_keep,lit(n_to_keep)).otherwise(col('ix_'+string_col)))

        # Overwrite the indexed column so that it carries the updated metadata.
        df = df.withColumn('ix_'+string_col, withMeta(col('ix_'+string_col), "", this_meta))

    return df

# SAMPLE DATA -----------------------------------------------------------------

df = pd.DataFrame({'x1' : ['a','b','a','b','c'],  # a: 0.4, b: 0.4, c: 0.2
                   'x2' : ['a','b','a','b','a'],  # a: 0.6, b: 0.4, c: 0.0
                   'x3' : ['a','a','a','a','a'],  # a: 1.0, b: 0.0, c: 0.0
                   'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)

# TEST THE FUNCTION -----------------------------------------------------------

df = group_low_freq(df,df.columns,0.25)    

ix_cols = [x for x in df.columns if 'ix_' in x]
for string_col in ix_cols:    
    idx_to_string = IndexToString(inputCol=string_col, outputCol=string_col[3:]+'grouped')
    df = idx_to_string.transform(df)

df.show()

Output with a threshold of 25% (so each value had to occur in at least 25% of the rows to keep its own group):

+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
| x1| x2| x3| x4|ix_x1|ix_x2|ix_x3|ix_x4|x1grouped|x2grouped|x3grouped|x4grouped|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
|  a|  a|  a|  a|  0.0|  0.0|  0.0|  0.0|        a|        a|        a|    other|
|  b|  b|  a|  b|  1.0|  1.0|  0.0|  0.0|        b|        b|        a|    other|
|  a|  a|  a|  c|  0.0|  0.0|  0.0|  0.0|        a|        a|        a|    other|
|  b|  b|  a|  d|  1.0|  1.0|  0.0|  0.0|        b|        b|        a|    other|
|  c|  a|  a|  e|  2.0|  0.0|  0.0|  0.0|    other|        a|        a|    other|
+---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
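
Since the metadata now reflects the grouped labels, the mapping between indices and strings that the question asked for can be read straight from the schema. A small sketch:

# Recover the index -> label mapping per indexed column,
# e.g. ix_x1 gives {0: 'a', 1: 'b', 2: 'other'}.
for string_col in ix_cols:
    labels = df.schema[string_col].metadata['ml_attr']['vals']
    print(string_col, dict(enumerate(labels)))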