How do I efficiently create a new column in a DataFrame that is a function of other rows in Spark?
This is a Spark implementation of the problem I described here:
import pandas as pd

from nltk.metrics.distance import edit_distance as edit_dist
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType

d = {
    'id': [1, 2, 3, 4, 5, 6],
    'word': ['cat', 'hat', 'hag', 'hog', 'dog', 'elephant']
}
spark_df = sqlCtx.createDataFrame(pd.DataFrame(d))

# Collect every word into a local list of Rows on the driver
words_list = list(spark_df.select('word').collect())

# For each word, count the other words whose edit distance to it is
# less than 2
get_n_similar = udf(
    lambda word: len(
        [
            w for w in words_list if (w['word'] != word) and
            (edit_dist(w['word'], word) < 2)
        ]
    ),
    IntegerType()
)

spark_df.withColumn('n_similar', get_n_similar(col('word'))).show()
Output:
+---+--------+---------+
|id |word    |n_similar|
+---+--------+---------+
|1  |cat     |1        |
|2  |hat     |2        |
|3  |hag     |2        |
|4  |hog     |2        |
|5  |dog     |1        |
|6  |elephant|0        |
+---+--------+---------+
The problem here is that I don't know of a way to tell Spark to compare the current row to the other rows in the DataFrame without first collecting the values into a list. Is there a way to apply a generic function over the other rows without calling collect?
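For reference, the closest I have come to avoiding collect is a crossJoin of the DataFrame against itself (available in Spark 2.1+), sketched below. The pair_dist UDF, the other_word alias, and the final fillna for words with no close neighbours are my own additions:

from pyspark.sql import functions as F

# UDF that compares just two words at a time
pair_dist = udf(lambda a, b: edit_dist(a, b), IntegerType())

# Pair every word with every other word via a cross join
pairs = spark_df.crossJoin(
    spark_df.select(col('word').alias('other_word'))
).where(col('word') != col('other_word'))

# Count, per word, the other words whose edit distance is less than 2
n_similar_df = (
    pairs
    .withColumn('dist', pair_dist(col('word'), col('other_word')))
    .where(col('dist') < 2)
    .groupBy('word')
    .agg(F.count('*').alias('n_similar'))
)

# Words with no close neighbour (e.g. 'elephant') drop out of the
# aggregate, so join back and fill the missing counts with 0
spark_df.join(n_similar_df, on='word', how='left') \
    .fillna(0, subset=['n_similar']) \
    .show()

This produces the same counts as above, but I assume the crossJoin blows up quadratically with the number of rows, so I would still like to know the idiomatic way to do this.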