I have the following situation.

+--------------------+
|                   p|
+--------------------+
|[0.99998416412131...|
|[0.99998416412131...|
|[0.99998416412131...|
|[0.99998416412131...|
|[0.99998416412131...|
+--------------------+

This is a list of Row() objects.

[Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
 Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
 Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
 Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06]),
 Row(p=[0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06])]

I am trying to derive a new column named "maxClass" from this column, holding the index of the largest value in each row (what np.argmax would return). Below is my best attempt, but I cannot get the hang of this package's idioms.

def f(row):
    return np.argmax(np.array(row.p))[0]
results=probs.rdd.map(lambda x:f(x))
results
Dharman
bmc
  • Does removing the `[0]` from your `f` function work for you? It's hard to tell without knowing exactly what your desired output is. – pault Sep 17 '19 at 22:06
  • Sorry it was so ambiguous. I just wanted the index associated with the largest value. I am getting a little closer, but still unsuccessful. `import pyspark.sql.functions as f from pyspark.sql.functions import udf def custom_function(row): return np.array(row["p"]).argmax() udf_custom_function = udf(custom_function) new = probs.withColumn("p_max", udf_custom_function("p"))` – bmc Sep 18 '19 at 15:11
  • Don't add [SOLVED] to your question. Instead, post your own solution as an answer if you think it will be useful for others OR delete the question. – pault Sep 18 '19 at 15:35
  • Also, you can [do this without a `udf`](https://stackoverflow.com/questions/38296609/spark-functions-vs-udf-performance?rq=1) if you know the size of the array column (which seems likely since this looks like a multi-class classification problem). – pault Sep 18 '19 at 15:36
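The UDF attempt from the comments above can be sketched roughly as follows. This is a minimal sketch, not the asker's final solution: `argmax_index` and `add_max_index` are hypothetical names, and the pure-Python argmax stands in for `np.argmax` on the assumption that returning a plain `int` with an explicit `IntegerType` is less error-prone than returning a NumPy scalar from a UDF. Note that `np.argmax` already returns a single index for a 1-D array, so the trailing `[0]` in the question is not needed.

```python
def argmax_index(values):
    """Return the index of the largest element (first one on ties).

    Returns None for a null or empty array so Spark stores a null.
    """
    if not values:
        return None
    return max(range(len(values)), key=values.__getitem__)


def add_max_index(df, col="p", out="maxClass"):
    """Hypothetical helper: append the argmax of array column `col` as `out`."""
    # Imported lazily so argmax_index can be used without Spark installed.
    from pyspark.sql import functions as F
    from pyspark.sql.types import IntegerType

    # The UDF receives the column value (a Python list), not a Row,
    # so there is no need to index it with ["p"].
    argmax_udf = F.udf(argmax_index, IntegerType())
    return df.withColumn(out, argmax_udf(F.col(col)))
```

Assuming `probs` is the DataFrame from the question, `add_max_index(probs)` should yield the extra "maxClass" column.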

1 Answer

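For the sake of completeness, and as pault suggested, here is a solution that uses neither a UDF nor numpy. Instead, it relies on array_position and array_max (both available since Spark 2.4). Because array_position is 1-based, the expression subtracts 1 to get a zero-based index: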

import pyspark.sql.functions as f

df = spark.createDataFrame([
  ([0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06],),
  ([0.9999841641213134, 0.99999, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06],),
  ([0.9999841641213135, 5.975696995141415e-06, 1.3699249952858219e-06, 1.4817184271708493e-06, 2.9022272149130313e-07, 1.4883436072406822e-06, 2.2234697862933896e-06, 3.006502154124559e-06],)]) \
.toDF("p")

df.select(
  f.expr('array_position(cast(p as array<decimal(16, 16)>), cast(array_max(p) as decimal(16, 16))) - 1').alias("max_indx")
).show()

# +--------+
# |max_indx|
# +--------+
# |       0|
# |       1|
# |       0|
# +--------+
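To make the `- 1` in the expression easier to follow, here is a plain-Python stand-in for what it computes. It is only an illustration under stated assumptions: `array_position_1_based` and `max_index` are hypothetical names that mimic Spark's `array_position` (1-based, 0 when the value is absent) and the overall expression, and the rows are shortened versions of the sample data above.

```python
def array_position_1_based(values, target):
    """Mimic Spark SQL's array_position: 1-based index of the first match, 0 if absent."""
    for position, value in enumerate(values, start=1):
        if value == target:
            return position
    return 0


def max_index(values):
    # Same shape as the SQL expression: array_position(p, array_max(p)) - 1
    return array_position_1_based(values, max(values)) - 1


# Shortened versions of the three sample rows used in the answer.
rows = [
    [0.9999841641213133, 5.975696995141415e-06, 1.3699249952858219e-06],
    [0.9999841641213134, 0.99999, 1.3699249952858219e-06],
    [0.9999841641213135, 5.975696995141415e-06, 1.3699249952858219e-06],
]

indices = [max_index(row) for row in rows]
print(indices)  # [0, 1, 0]
```

The resulting zero-based index can then be joined against a lookup table that maps index to class.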
abiratsis
  • This is very close to what I want, but I need the index associated with it, not the value itself. As these are probabilities associated with a class, I have a lookup table for the index-class mapping to the different categorical variable associated with it. – bmc Sep 18 '19 at 15:12
  • They want the position of the max – pault Sep 18 '19 at 15:12