
My question has two parts. The first is about understanding how Spark works, and the second is about optimization.

I have a Spark DataFrame with multiple categorical variables. For each of these categorical variables I am adding a new column in which each row holds the frequency of the corresponding level.

For example:

Date_Built  Square_Footage  Num_Beds    Num_Baths   State   Price     Freq_State
01/01/1920  1700            3           2           NY      700000    4500

Here, for State (a categorical variable), I am adding a new column Freq_State. The level NY appears 4500 times in the dataset, so this row gets 4500 in the Freq_State column.

I have multiple such categorical columns, and for each of them I add a column holding the frequency of the corresponding level.

This is the code I am using to achieve this:

def calculate_freq(df, categorical_cols):
    for each_cat_col in categorical_cols:
        # Frequency of each level of the current categorical column
        _freq = df.select(each_cat_col).groupBy(each_cat_col).count()
        # Join the counts back onto the original DataFrame
        df = df.join(_freq, each_cat_col, "inner")
    return df
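
For reference, I call it roughly like this (toy data only, just to show the shape of the call; the real DataFrame has many more rows and columns):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy data only, for illustration
df = spark.createDataFrame(
    [("NY", 700000), ("NY", 550000), ("CA", 900000)],
    ["State", "Price"])

result = calculate_freq(df, ["State"])
result.show()  # every row now carries the count of its State level (the column is named "count")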

Part 1

Here, as you can see, I am updating the DataFrame inside the for loop. Is this way of updating a DataFrame advisable when running the code on a cluster? I wouldn't be concerned about this with a pandas DataFrame, but I am not sure whether it matters once the context changes to Spark.

Also, would it make a difference if I was simply running the above process in a loop and not inside a function?

Part 2

Is there a more optimized way to do this? Here I am performing a join on every iteration of the loop. Can this be avoided?

Clock Slave

1 Answer


Is there a more optimized way to do this?

What are possible alternatives?

  1. You could use Window functions

    from pyspark.sql.functions import count
    from pyspark.sql.window import Window

    def calculate_freq(df, categorical_cols):
        for cat_col in categorical_cols:
            # Count all rows that share this column's value, no join needed
            w = Window.partitionBy(cat_col)
            df = df.withColumn("{}_freq".format(cat_col), count("*").over(w))
        return df
    

    Should you? No. Unlike a join (which can broadcast the small aggregated side), it always requires a full shuffle of the non-aggregated DataFrame; you can compare the physical plans with df.explain() to see this.

  2. You could melt the data and use a single local object (this requires all categorical columns to be of the same type; a sketch of a possible melt helper is included after this list):

    from itertools import groupby

    # Cast every categorical column to a common type so they can be melted together
    for c in categorical_cols:
        df = df.withColumn(c, df[c].cast("string"))

    # Long format: one (variable, value) pair per row, counted and collected locally
    rows = (melt(df, id_vars=[], value_vars=categorical_cols)
            .groupBy("variable", "value").count().collect())

    mapping = {k: {x.value: x["count"] for x in v}
               for k, v in groupby(sorted(rows), lambda x: x.variable)}
    

    And use a udf to add the values:

    from pyspark.sql.functions import udf

    def get_count(mapping_c):
        # Build a udf that looks up the frequency of one categorical column's values
        @udf("bigint")
        def _(x):
            return mapping_c.get(x)
        return _


    for c in categorical_cols:
        df = df.withColumn("{}_freq".format(c), get_count(mapping[c])(c))
    

    Should you? Maybe. Unlike the iterative join, it requires only a single action to compute all the statistics. If the result is small (expected with categorical variables), you can get a moderate performance boost.

  3. You could add a broadcast hint.

    from pyspark.sql.functions import broadcast
    
    def calculate_freq(df, categorical_cols):
        for each_cat_col in categorical_cols:
            _freq = df.select(each_cat_col).groupBy(each_cat_col).count()
            # Broadcast the small frequency table so only it is shipped to the executors
            df = df.join(broadcast(_freq), each_cat_col, "inner")
        return df
    

    Spark should broadcast the small side automatically, so it shouldn't change a thing, but it is always better to help the planner.
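
For completeness, melt used in the second option is not a built-in DataFrame method here; a helper along these lines does the unpivoting (one possible sketch, assuming the value columns share a type, as ensured by the casts above):

    from pyspark.sql.functions import array, col, explode, lit, struct

    def melt(df, id_vars, value_vars, var_name="variable", value_name="value"):
        # One (variable, value) struct per value column, collected into an array
        vars_and_vals = array(*(
            struct(lit(c).alias(var_name), col(c).alias(value_name))
            for c in value_vars))
        # Explode the array so each input row yields one row per value column
        tmp = df.withColumn("_vars_and_vals", explode(vars_and_vals))
        cols = [col(c) for c in id_vars] + [
            col("_vars_and_vals")[name].alias(name) for name in (var_name, value_name)]
        return tmp.select(*cols)

With id_vars=[], the result has exactly the variable and value columns that the groupBy above expects.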

Also, would it make a difference if I was simply running the above process in a loop and not inside a function?

Ignoring code maintainability and testability, it would not.
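
For illustration, inlining the loop builds exactly the same lazy plan as the function version (sketch reusing the code from the question):

    # Same transformations, just not wrapped in a function; Spark only records
    # the lazy plan here, so nothing is executed until an action is called
    for each_cat_col in categorical_cols:
        _freq = df.select(each_cat_col).groupBy(each_cat_col).count()
        df = df.join(_freq, each_cat_col, "inner")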

Alper t. Turker