Just to complement the accepted answer: one can now do something very similar to what the OP tried to do, i.e., use the `**` operator, or even Python's built-in `pow`:
from pyspark.sql import SparkSession
from pyspark.sql.functions import pow as pow_
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, ), (2, ), (3, ), (4, ), (5, ), (6, )], 'n: int')
df = df.withColumn('pyspark_pow', pow_(df['n'], df['n'])) \
       .withColumn('python_pow', pow(df['n'], df['n'])) \
       .withColumn('double_star_operator', df['n'] ** df['n'])
df.show()
+---+-----------+----------+--------------------+
|  n|pyspark_pow|python_pow|double_star_operator|
+---+-----------+----------+--------------------+
|  1|        1.0|       1.0|                 1.0|
|  2|        4.0|       4.0|                 4.0|
|  3|       27.0|      27.0|                27.0|
|  4|      256.0|     256.0|               256.0|
|  5|     3125.0|    3125.0|              3125.0|
|  6|    46656.0|   46656.0|             46656.0|
+---+-----------+----------+--------------------+
As one can see, PySpark's `pow`, Python's built-in `pow`, and the `**` operator all return the same result. It also works when one of the arguments is a scalar:
df = df.withColumn('pyspark_pow', pow_(2, df['n'])) \
       .withColumn('python_pow', pow(2, df['n'])) \
       .withColumn('double_star_operator', 2 ** df['n'])
df.show()
+---+-----------+----------+--------------------+
|  n|pyspark_pow|python_pow|double_star_operator|
+---+-----------+----------+--------------------+
|  1|        2.0|       2.0|                 2.0|
|  2|        4.0|       4.0|                 4.0|
|  3|        8.0|       8.0|                 8.0|
|  4|       16.0|      16.0|                16.0|
|  5|       32.0|      32.0|                32.0|
|  6|       64.0|      64.0|                64.0|
+---+-----------+----------+--------------------+
I believe the reason Python's `pow` now works on PySpark columns is that `pow` is equivalent to the `**` operator when called with only two arguments (see the Python docs), and the `**` operator uses the object's own implementation of the power operation, if one is defined for the object being operated on (see the related SO answer).
Apparently, PySpark's `Column` has the proper definitions for the `__pow__` operator (see the source code for `Column`).
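To illustrate the dispatch described above without needing a Spark session, here is a minimal pure-Python sketch. The `Expr` class is a hypothetical stand-in, not PySpark's `Column`; it only assumes that a column-like object builds a new expression from its operands. The reflected method `__rpow__` is what Python falls back to when the left operand is a plain scalar, which is presumably how the `2 ** df['n']` case is handled.

```python
class Expr:
    """Toy stand-in for a column expression (hypothetical, not PySpark's Column)."""

    def __init__(self, text):
        self.text = text

    @staticmethod
    def _render(value):
        # Show nested expressions by their text and plain Python values by repr().
        return value.text if isinstance(value, Expr) else repr(value)

    def __pow__(self, other):
        # Called for `self ** other` and for two-argument `pow(self, other)`.
        return Expr(f"POWER({self.text}, {self._render(other)})")

    def __rpow__(self, other):
        # Called for `other ** self` when `other` (e.g. an int) cannot handle an Expr.
        return Expr(f"POWER({self._render(other)}, {self.text})")


n = Expr("n")

print((n ** n).text)   # POWER(n, n)
print(pow(n, n).text)  # POWER(n, n)
print((2 ** n).text)   # POWER(2, n)
print(pow(2, n).text)  # POWER(2, n)
```

In each pair, the operator and the two-argument built-in `pow` go through the same special method, which matches the identical columns in the tables above.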
I am not sure why the `**` operator did not work originally, but I assume it is because `Column` was defined differently at the time.
The stack used for testing was Python 3.8.5 and PySpark 3.1.1, but I have seen this behavior for PySpark >= 2.4 as well.