I have seen in other posts that this can be done for DataFrames: https://stackoverflow.com/a/52992212/4080521
But I am trying to figure out how I can write a UDF for a cumulative product.
Assuming I have a very basic table
Input data:
+----+
| val|
+----+
|   1|
|   2|
|   3|
+----+
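For reference, I build this toy table as a DataFrame that I call df here (in a spark-shell with implicits available):

import spark.implicits._

// Toy DataFrame matching the table above; "df" is just the name I use here.
val df = Seq(1, 2, 3).toDF("val")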
If I want to take the sum of this, I can simply do something like:
df.createOrReplaceTempView("table")
spark.sql("""SELECT SUM(table.val) FROM table""").show(100, false)
and this simply works because SUM is a predefined function.
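which prints something like:

+--------+
|sum(val)|
+--------+
|       6|
+--------+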
How would I define something similar for multiplication (or, for that matter, how could I implement SUM in a UDF myself)?
Trying the following:

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.Decimal

df.createOrReplaceTempView("table")
val prod = udf((vals: Seq[Decimal]) => vals.reduce(_ * _))
spark.udf.register("prod", prod)
spark.sql("""SELECT prod(table.vals) FROM table""").show(100, false)
I get the following error:
Message: cannot resolve 'UDF(vals)' due to data type mismatch: argument 1 requires array<decimal(38,18)> type, however, 'table.vals' is of decimal(28,14)
Obviously each individual cell is not an array, but it seems the UDF needs to take an array in order to perform the aggregation. Is this even possible with Spark SQL?
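For what it's worth, my current guess is that I need a user-defined aggregate rather than a plain UDF, something along the lines of the sketch below (assuming Spark 3.x, simplifying the column to Double, and with ProductAgg being a name I made up), but I have not verified that this is the right approach:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// Sketch of a product aggregate; ProductAgg is my own name, not a built-in.
object ProductAgg extends Aggregator[Double, Double, Double] {
  def zero: Double = 1.0                                // multiplicative identity
  def reduce(acc: Double, x: Double): Double = acc * x  // fold one row into the buffer
  def merge(a: Double, b: Double): Double = a * b       // combine partial products across partitions
  def finish(acc: Double): Double = acc
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register it like SUM so it can be called from SQL (Spark 3.0+):
spark.udf.register("prod", functions.udaf(ProductAgg))
spark.sql("""SELECT prod(val) FROM table""").show(100, false)

If a plain UDF really is the way to go, I assume I would first have to wrap the column in something like collect_list to get an array, which seems wasteful for a large table.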