
I'm working with the following snippet:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from cape_privacy.pandas.transformations import Tokenizer

max_token_len = 5


@pandas_udf("string")
def Tokenize(column: pd.Series) -> pd.Series:
  tokenizer = Tokenizer(max_token_len)
  return tokenizer(column)


spark_df = spark_df.withColumn("name", Tokenize("name"))

Since a Pandas UDF only takes a Pandas Series as input, I can't pass the max_token_len argument in the call Tokenize("name").

I therefore have to define max_token_len outside the function's scope.

The workarounds suggested in this question weren't really helpful. Are there any other possible workarounds or alternatives?

Please advise.

The Singularity

2 Answers


After trying a number of approaches, I found a straightforward solution, illustrated below:

I created a wrapper function (Tokenize_wrapper) around the Pandas UDF (Tokenize_udf). The wrapper takes max_token_len as an ordinary argument and returns the result of calling the Pandas UDF.

def Tokenize_wrapper(column, max_token_len=10):
  # max_token_len is an ordinary Python argument here...

  @pandas_udf("string")
  def Tokenize_udf(column: pd.Series) -> pd.Series:
    # ...and is captured by the inner UDF as a closure variable
    tokenizer = Tokenizer(max_token_len)
    return tokenizer(column)

  return Tokenize_udf(column)



df = df.withColumn("Name", Tokenize_wrapper("Name", max_token_len=5))

Using partial functions (@Vaebhav's answer) actually made the implementation more difficult in my case.
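A closely related variant (a sketch of my own, not part of the solution above) is to have the wrapper return the Pandas UDF itself instead of its result, so a UDF can be built once per max_token_len and reused across columns:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from cape_privacy.pandas.transformations import Tokenizer


def make_tokenize_udf(max_token_len=10):
    """Build a Pandas UDF with max_token_len baked in via a closure."""

    @pandas_udf("string")
    def tokenize_udf(column: pd.Series) -> pd.Series:
        tokenizer = Tokenizer(max_token_len)
        return tokenizer(column)

    return tokenize_udf


# Hypothetical usage: build the UDF once, then apply it to any column
tokenize_5 = make_tokenize_udf(max_token_len=5)
df = df.withColumn("Name", tokenize_5("Name"))

The per-row behaviour is the same as Tokenize_wrapper; the only difference is when the UDF object gets created.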

The Singularity

You can achieve this by using partial, or by specifying the additional argument(s) directly in your UDF signature.

Data Preparation

input_list = [
    (1, None, 111),
    (1, None, 120),
    (1, None, 121),
    (1, None, 124),
    (1, 'p1', 125),
    (1, None, 126),
    (1, None, 146),
    (1, None, 147),
]

sparkDF = sql.createDataFrame(input_list,['id','p_id','timestamp'])

sparkDF.show()

+---+----+---------+
| id|p_id|timestamp|
+---+----+---------+
|  1|null|      111|
|  1|null|      120|
|  1|null|      121|
|  1|null|      124|
|  1|  p1|      125|
|  1|null|      126|
|  1|null|      146|
|  1|null|      147|
+---+----+---------+

Partial


from functools import partial
from pyspark.sql import functions as F


def add_constant(inp, cnst=5):
    return inp + cnst


cnst_add = 10

# Bind cnst once; the resulting partial is called with just the Column
partial_func = partial(add_constant, cnst=cnst_add)

sparkDF = sparkDF.withColumn('Constant_Partial', partial_func(F.col('timestamp')))

sparkDF.show()

+---+----+---------+----------------+
| id|p_id|timestamp|Constant_Partial|
+---+----+---------+----------------+
|  1|null|      111|             121|
|  1|null|      120|             130|
|  1|null|      121|             131|
|  1|null|      124|             134|
|  1|  p1|      125|             135|
|  1|null|      126|             136|
|  1|null|      146|             156|
|  1|null|      147|             157|
+---+----+---------+----------------+
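Note that partial_func here is never registered as a UDF: add_constant receives a Column, and inp + cnst builds a Column expression that Spark evaluates natively. A minimal equivalent written out explicitly (assuming the same DataFrame and cnst_add as above) would be:

from pyspark.sql import functions as F

# Same result as the partial example: plain Column arithmetic, no UDF involved
sparkDF = sparkDF.withColumn('Constant_Partial', F.col('timestamp') + cnst_add)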

UDF Signature

from pyspark.sql.types import IntegerType

cnst_add = 10

# The constant is captured inside the lambda that becomes the UDF body
add_constant_udf = F.udf(lambda x: add_constant(x, cnst_add), IntegerType())

sparkDF = sparkDF.withColumn('Constant_UDF', add_constant_udf(F.col('timestamp')))

sparkDF.show()

+---+----+---------+------------+
| id|p_id|timestamp|Constant_UDF|
+---+----+---------+------------+
|  1|null|      111|         121|
|  1|null|      120|         130|
|  1|null|      121|         131|
|  1|null|      124|         134|
|  1|  p1|      125|         135|
|  1|null|      126|         136|
|  1|null|      146|         156|
|  1|null|      147|         157|
+---+----+---------+------------+
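The two approaches can also be combined: bind the constant with partial and register the resulting callable as the UDF. A sketch using the same names and imports as above:

from functools import partial
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# partial objects are ordinary callables, so F.udf accepts them directly
add_constant_udf = F.udf(partial(add_constant, cnst=cnst_add), IntegerType())

sparkDF = sparkDF.withColumn('Constant_UDF', add_constant_udf(F.col('timestamp')))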

Similarly, you can transform your function as below:

import pandas as pd
from functools import partial
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from cape_privacy.pandas.transformations import Tokenizer

max_token_len = 5

def Tokenize(column: pd.Series, max_token_len=10) -> pd.Series:
    tokenizer = Tokenizer(max_token_len)
    return tokenizer(column)

# Option 1: bind max_token_len inside a lambda and register it as a UDF
Tokenize_udf = F.udf(lambda x: Tokenize(x, max_token_len), StringType())

# Option 2: bind max_token_len with partial (works when the function can
# operate on a Column directly, as add_constant does above)
Tokenize_partial = partial(Tokenize, max_token_len=max_token_len)

spark_df = spark_df.withColumn("name", Tokenize_udf("name"))
spark_df = spark_df.withColumn("name", Tokenize_partial(F.col("name")))

Vaebhav