I have a user-defined function:

import math
from pyspark.sql.functions import udf, col
from pyspark.sql.types import FloatType

param1 = "A"

def calculate(type, pos):
    if param1 == "A":
        a, b = [0.05, -0.06]
    else:
        a, b = [0.15, -0.16]
    return a * math.pow(type, b) * max(pos, 1)

calc = udf(calculate, FloatType())

result = df.withColumn('col1', calc(col('type'), col('pos'))).groupBy('pk').sum('events')

I need to pass a parameter param1 to this udf. How can I do it?

Dinosaurius

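You can pass the parameter as a constant column with lit (or, in Scala, typedLit, which can also handle types such as Seq and Map) and give your UDF one extra argument for it, like this: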

In Python:

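from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, lit
from pyspark.sql.types import LongType
spark = SparkSession.builder.getOrCreate()
# Without an explicit return type the UDF would return strings; lit(3) feeds 3 to every row
mult = udf(lambda value, multiplier: value * multiplier, LongType())
df = spark.sparkContext.parallelize([(1,), (2,), (3,)]).toDF()
df.select(mult(col("_1"), lit(3))).show()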

In Scala:

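import org.apache.spark.sql.functions.{udf, col, lit}
import spark.implicits._  // provides .toDF on an RDD
// lit(3) feeds the same constant to the multiplier argument on every row
val mult = udf((value: Int, multiplier: Int) => value * multiplier)
val df = spark.sparkContext.parallelize(1 to 10).toDF
df.select(mult(col("value"), lit(3))).show()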
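Applied to the function in the question, a minimal sketch could look like the following. It assumes a DataFrame df with the type, pos, pk and events columns used in the question; the helper name with_col1 and the argument name type_ are hypothetical, introduced only for this example. param1 becomes a third UDF argument supplied through lit, so every row receives the same string. The null check is an addition not in the question, since Spark passes None to a Python UDF for null cells.

```python
import math


def calculate(type_, pos, param1):
    """Row-level logic from the question, with param1 passed in explicitly."""
    # Spark passes None for null cells; return None so the result is null too.
    if type_ is None or pos is None:
        return None
    if param1 == "A":
        a, b = 0.05, -0.06
    else:
        a, b = 0.15, -0.16
    return a * math.pow(type_, b) * max(pos, 1)


def with_col1(df, param1):
    """Hypothetical helper: adds col1 using param1, then aggregates as in the question."""
    # Imported lazily so calculate() stays a plain Python function usable without Spark.
    from pyspark.sql.functions import col, lit, udf
    from pyspark.sql.types import FloatType

    calc = udf(calculate, FloatType())
    # lit(param1) wraps the Python string in a constant column,
    # so each row gets the same value as the UDF's third argument.
    return (
        df.withColumn("col1", calc(col("type"), col("pos"), lit(param1)))
        .groupBy("pk")
        .sum("events")
    )
```

With this in place, result = with_col1(df, "A") replaces the original line, and switching to the other coefficients is just with_col1(df, "B"), with no global variable involved.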
Paul V