
How do I efficiently create a new column in a DataFrame that is a function of other rows in Spark?

This is a Spark implementation of the problem I described here:

import pandas as pd

from nltk.metrics.distance import edit_distance as edit_dist
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType

d = {
    'id': [1, 2, 3, 4, 5, 6],
    'word': ['cat', 'hat', 'hag', 'hog', 'dog', 'elephant']
}

spark_df = sqlCtx.createDataFrame(pd.DataFrame(d))
# collect all words to the driver so the UDF can compare each word against them
words_list = list(spark_df.select('word').collect())

# count how many other words are within edit distance 1 of the given word
get_n_similar = udf(
    lambda word: len(
        [
            w for w in words_list if (w['word'] != word) and 
            (edit_dist(w['word'], word) < 2)
        ]
    ),
    IntegerType()
)

spark_df.withColumn('n_similar', get_n_similar(col('word'))).show()

Output:

+---+--------+---------+
|id |word    |n_similar|
+---+--------+---------+
|1  |cat     |1        |
|2  |hat     |2        |
|3  |hag     |2        |
|4  |hog     |2        |
|5  |dog     |1        |
|6  |elephant|0        |
+---+--------+---------+

The problem here is that I don't know of a way to tell Spark to compare the current row to the other rows in the DataFrame without first collecting the values into a list. Is there a way to apply a generic function over other rows without calling collect?

pault

1 Answer


The problem here is that I don't know a way to tell spark to compare the current row to the other rows in the Dataframe without first collecting the values into a list.

A UDF is not an option here (you cannot reference a distributed DataFrame inside a udf). A direct translation of your logic is a Cartesian product plus an aggregate:

from pyspark.sql.functions import levenshtein, col

result = (spark_df.alias("l")
    .crossJoin(spark_df.alias("r"))               # compare every word with every other word
    .where(levenshtein("l.word", "r.word") < 2)   # keep pairs within edit distance 1
    .where(col("l.word") != col("r.word"))        # drop self-matches
    .groupBy("l.id", "l.word")
    .count())

In practice, though, you should try to do something more efficient: see Efficient string matching in Apache Spark.

Depending on the problem, you should try to find other approximations that avoid the full Cartesian product, as sketched below.
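For example, here is a minimal sketch assuming a simple length-based blocking key (the "block" column and "approx_result" name are illustrative, not part of the answer above): edit distance < 2 means the two word lengths can differ by at most 1, so you can equi-join on candidate lengths first and only evaluate levenshtein on the surviving pairs.

from pyspark.sql.functions import array, col, explode, length, levenshtein

# each word can only match words of length len - 1, len or len + 1,
# so generate those candidate lengths on one side and equi-join on them
# instead of taking the full Cartesian product
left = spark_df.withColumn(
    "block", explode(array(length("word") - 1, length("word"), length("word") + 1))
)
right = spark_df.withColumn("block", length("word"))

approx_result = (left.alias("l")
    .join(right.alias("r"), col("l.block") == col("r.block"))
    .where(col("l.word") != col("r.word"))
    .where(levenshtein("l.word", "r.word") < 2)
    .groupBy("l.id", "l.word")
    .count())

Each qualifying pair matches on exactly one block value, so the counts agree with the crossJoin version, but far fewer pairs reach the levenshtein check when the words span many different lengths.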

If you want to keep the words that have no matches, you can skip one filter and correct the count afterwards:

(spark_df.alias("l")
    .crossJoin(spark_df.alias("r"))
    .where(levenshtein("l.word", "r.word") < 2)
    .groupBy("l.id", "l.word")
    .count()
    .withColumn("count", col("count") - 1))   # subtract 1 to remove the self-match

or (slower, but more generic) join the result back with the reference data:

(spark_df
    .select("id", "word")
    .distinct()
    .join(result, ["id", "word"], "left")   # left join keeps words with no matches
    .na.fill(0))                            # and fills their missing count with 0
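
As a usage sketch (the n_similar rename and the ordering are assumptions added for readability; result comes from the first snippet above), this reproduces the question's n_similar column, including a 0 for elephant, which has no close matches:

(spark_df
    .select("id", "word")
    .distinct()
    .join(result, ["id", "word"], "left")
    .na.fill(0)
    .withColumnRenamed("count", "n_similar")   # match the column name from the question
    .orderBy("id")
    .show())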
Alper t. Turker