
Given a PySpark DataFrame, for example:

import pandas as pd

ls = [
    ['1', 2],
    ['2', 7],
    ['1', 3],
    ['2',-6],
    ['1', 3],
    ['1', 5],
    ['1', 4],
    ['2', 7]
]
df = spark.createDataFrame(pd.DataFrame(ls, columns=['col1', 'col2']))
df.show()

+----+-----+
|col1| col2|
+----+-----+
|   1|    2|
|   2|    7|
|   1|    3|
|   2|   -6|
|   1|    3|
|   1|    5|
|   1|    4|
|   2|    7|
+----+-----+

How can I apply a function to the col2 values where col1 == '1' and store the result in a new column? For example, the function is:

f(x) = x**2

Result should look like this:

+----+-----+-----+
|col1| col2|    y|
+----+-----+-----+
|   1|    2|    4|
|   2|    7| null|
|   1|    3|    9|
|   2|   -6| null|
|   1|    3|    9|
|   1|    5|   25|
|   1|    4|   16|
|   2|    7| null|
+----+-----+-----+

I tried defining a separate function and using df.withColumn("y", ...).when(condition, function), but it wouldn't work.

So what is a way to do this?

kiwii

1 Answer


I hope this helps:

from pyspark.sql.functions import when
from pyspark.sql.types import IntegerType

def myFun(x):
    # x**2 on a Column produces a double, so cast back to integer
    return (x**2).cast(IntegerType())

# Rows that fail the when() condition are null; otherwise(None) makes that explicit
df2 = df.withColumn("y", when(df.col1 == '1', myFun(df.col2)).otherwise(None))

df2.show()

+----+----+----+
|col1|col2|   y|
+----+----+----+
|   1|   2|   4|
|   2|   7|null|
|   1|   3|   9|
|   2|  -6|null|
|   1|   3|   9|
|   1|   5|  25|
|   1|   4|  16|
|   2|   7|null|
+----+----+----+
michalrudko