Given a PySpark DataFrame, for example:
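import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

ls = [
    ['1', 2],
    ['2', 7],
    ['1', 3],
    ['2', -6],
    ['1', 3],
    ['1', 5],
    ['1', 4],
    ['2', 7],
]
df = spark.createDataFrame(pd.DataFrame(ls, columns=['col1', 'col2']))
df.show()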
+----+----+
|col1|col2|
+----+----+
|   1|   2|
|   2|   7|
|   1|   3|
|   2|  -6|
|   1|   3|
|   1|   5|
|   1|   4|
|   2|   7|
+----+----+
How can I apply a function to the col2 values only where col1 == '1', and store the result in a new column y? For example, the function is:
f = lambda x: x ** 2
The result should look like this:
+----+----+----+
|col1|col2|   y|
+----+----+----+
|   1|   2|   4|
|   2|   7|null|
|   1|   3|   9|
|   2|  -6|null|
|   1|   3|   9|
|   1|   5|  25|
|   1|   4|  16|
|   2|   7|null|
+----+----+----+
I tried defining a separate function and using df.withColumn(y).when(condition, function), but it didn't work.
What is the right way to do this?