
I have to find neighbors of a specific data point in a pyspark dataframe.

a = spark.createDataFrame([("A", [0,1]), ("B", [5,9]), ("D", [13,5])], ["Letter", "distances"])

I have created this function that takes in the dataframe (DB), computes the Euclidean distance from each data point to a fixed point (Q), keeps only the points within some epsilon value (eps), and returns that subset.

from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
from scipy.spatial import distance

def rangequery(DB, Q, eps):
    distance_udf = F.udf(lambda x: float(distance.euclidean(x, Q)), FloatType())
    df_neigh = DB.withColumn('euclid_distances', distance_udf(F.col('distances')))
    return df_neigh.filter(df_neigh['euclid_distances'] <= eps)
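
For a single query point this behaves as expected. For example, with the sample data above only row A lies within eps = 9 of the point [0, 1] ([5, 9] and [13, 5] are roughly 9.4 and 13.6 away), so this call returns just that row:

rangequery(a, [0, 1], 9).show()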

But now I need to run this function for every single point in the dataframe.

So I do the following.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def check_neighbours(distances):
    df = rangequery(a, distances, 9)
    if df.count() >= 1:
        return "Has Neighbours"
    else:
        return "No Neighbours"

udf_neigh = udf(check_neighbours, StringType())
a.withColumn("label", udf_neigh(a["distances"])).show()

I get the following error when I try to run this code.

PicklingError: Could not serialize object: Py4JError: An error occurred while calling o380.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
    at py4j.Gateway.invoke(Gateway.java:272)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:745)

I asked a similar question [here](https://stackoverflow.com/questions/48174484/new-dataframe-column-as-a-generic-function-of-other-rows-spark) and you may find [this answer](https://stackoverflow.com/a/48174947/5858851) useful. – pault Jan 09 '18 at 19:47

1 Answer


Borrowing heavily from this answer, here's one way to do it. Consider the following example:

from pyspark.sql.functions import col, udf
from pyspark.sql.types import FloatType
from scipy.spatial import distance

# create dummy dataset
DB = sqlCtx.createDataFrame(
    [("A", [0,1]), ("B", [5,9]), ("D", [13,5])],
    ["Letter", "distances"]
)

# Define your distance metric as a udf
distance_udf = udf(lambda x, y: float(distance.euclidean(x, y)), FloatType())

# Use crossJoin() to compute distances.
eps = 9  # neighbourhood radius: keep pairs closer than this
DB.alias("l")\
    .crossJoin(DB.alias("r"))\
    .where(distance_udf(col("l.distances"), col("r.distances")) < eps)\
    .groupBy("l.letter", "l.distances")\
    .count()\
    .withColumn("count", col("count") - 1)\
    .withColumn("label", udf(lambda x: "Has Neighbours" if x >= 1 else "No Neighbours")(col("count")))\
    .sort('letter')\
    .show()

Output:

+------+---------+-----+--------------+
|letter|distances|count|         label|
+------+---------+-----+--------------+
|     A|   [0, 1]|    0| No Neighbours|
|     B|   [5, 9]|    1|Has Neighbours|
|     D|  [13, 5]|    1|Has Neighbours|
+------+---------+-----+--------------+

Where the .withColumn("count", col("count") - 1) is done because we know that each row will have itself as a trivial neighbor. (You can remove this line depending on your needs.)
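
As a side note, since the distance vectors here always have exactly two elements, you could also skip the Python udf for the distance and build it from column arithmetic (just a sketch of an alternative, assuming 2-element arrays throughout):

from pyspark.sql.functions import sqrt

euclid = sqrt(
    (col("l.distances")[0] - col("r.distances")[0]) ** 2
    + (col("l.distances")[1] - col("r.distances")[1]) ** 2
)

This expression can replace distance_udf(col("l.distances"), col("r.distances")) in the where() clause above.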

Your code as written doesn't work because, as mentioned by @user8371915 in the linked post:

you cannot reference distributed DataFrame in udf
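
The crossJoin() above works around that restriction. Another option, if the data is small enough to collect to the driver, is to close over a plain Python list inside the udf instead of the DataFrame itself. This is only a sketch of that idea (local_points is a name introduced here, and eps is hard-coded to 9 as in your example):

from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType
from scipy.spatial import distance

# pull the (small) set of points to the driver as an ordinary Python list
local_points = [row["distances"] for row in DB.select("distances").collect()]

def check_neighbours(point):
    # count the *other* points within eps of this point, using only local data
    neighbours = sum(
        1 for p in local_points
        if p != point and distance.euclidean(p, point) < 9
    )
    return "Has Neighbours" if neighbours >= 1 else "No Neighbours"

DB.withColumn("label", udf(check_neighbours, StringType())(col("distances"))).show()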
