1

I've created the following graph:

spark = SparkSession.builder.appName('aggregate').getOrCreate()

# Vertex table: string id, a name, and the integer value to be aggregated.
vertices = spark.createDataFrame(
    [
        ('1', 'foo', 99),
        ('2', 'bar', 10),
        ('3', 'baz', 25),
        ('4', 'spam', 7),
    ],
    ['id', 'name', 'value'],
)

# Directed parent -> child edges: 1 -> 2, 1 -> 3, 3 -> 4.
edges = spark.createDataFrame(
    [('1', '2'), ('1', '3'), ('3', '4')],
    ['src', 'dst'],
)

g = GraphFrame(vertices, edges)

I would like to aggregate the messages such that, for any given vertex, we get a list of its own value plus the values of all of its descendant vertices, all the way down to the leaves. For example, from vertex 1 we have a child edge to vertex 3, which has a child edge to vertex 4. We also have a child edge from vertex 1 to vertex 2. That is:

(1) --> (3) --> (4)
  \
   \--> (2)

From vertex 1 I'd like to collect all values in this subtree: [99, 10, 25, 7], where 99 is the value of vertex 1 itself, 10 is the value of child vertex 2, 25 is the value of vertex 3, and 7 is the value of vertex 4.

From 3 we'd have the values [25, 7], etc.

I can approximate this with aggregateMessages:

agg = g.aggregateMessages(
    collect_list(AM.msg).alias('allValues'),
    sendToSrc=AM.dst['value'],  # each edge carries the child's value to its parent
    sendToDst=None,             # nothing is sent from parent to child
)

agg.show()

Which produces:

+---+---------+
| id|allValues|
+---+---------+
|  3|      [7]|
|  1| [25, 10]|
+---+---------+

At 1 we have [25, 10] which are the immediate child values, but we are missing 7 and the "self" value of 99.

Similarly, for vertex 3 I'm missing its own value, 25.

How can I aggregate messages "recursively", such that allValues from child vertices are aggregated at the parent?

Julio
  • 2,261
  • 4
  • 30
  • 56

1 Answer

0

I adapted this answer to your question and then wrangled its result to get your desired output. I admit it's a very ugly solution, but I hope it will be helpful as a starting point towards a more efficient and elegant implementation.

from graphframes import GraphFrame
from graphframes.lib import Pregel
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *


# Vertex table: string id, a name, and the integer value to be aggregated.
vertices = spark.createDataFrame(
    [
        ('1', 'foo', 99),
        ('2', 'bar', 10),
        ('3', 'baz', 25),
        ('4', 'spam', 7),
    ],
    ['id', 'name', 'value'],
)

# Directed parent -> child edges: 1 -> 2, 1 -> 3, 3 -> 4.
edges = spark.createDataFrame(
    [('1', '2'), ('1', '3'), ('3', '4')],
    ['src', 'dst'],
)

g = GraphFrame(vertices, edges)

### Adapted from previous answer

# Per-vertex Pregel state: distance from the starting vertex, the vertex's
# own id, and the list of vertex ids on the path that reached it.
vertColSchema = (
    StructType()
    .add("dist", DoubleType())
    .add("node", StringType())
    .add("path", ArrayType(StringType(), True))
)

def vertexProgram(vd, msg):
    """Merge the aggregated incoming message into a vertex's current state.

    Both ``vd`` and ``msg`` are ``(dist, node, path)`` triples matching
    ``vertColSchema``; ``msg`` is ``None`` when the vertex received no
    message in this superstep.

    Returns the current state unchanged when there is no message or the
    current distance is strictly smaller than the message's. Otherwise it
    adopts the message's distance and path while keeping this vertex's own
    ``node`` id.
    """
    dist, node, path = vd[0], vd[1], vd[2]
    if msg is None or dist < msg[0]:
        return (dist, node, path)
    # An equal-or-shorter distance wins; the vertex keeps its own id.
    return (msg[0], node, msg[2])

# Wrap vertexProgram as a Spark UDF whose return type is the vertColSchema struct.
vertexProgramUdf = F.udf(vertexProgram, vertColSchema)

def sendMsgToDst(src, dst):
    """Build the message an edge's source vertex sends to its destination.

    ``src`` and ``dst`` are ``(dist, node, path)`` vertex states. A message
    is produced only when ``src``'s distance is smaller than ``dst``'s
    distance minus one. It carries the incremented distance, the sender's
    node id, and the sender's path extended with the destination's id.
    Otherwise ``None`` is returned, meaning no message is sent.
    """
    src_dist, src_node, src_path = src[0], src[1], src[2]
    if not src_dist < dst[0] - 1:
        return None
    return (src_dist + 1, src_node, src_path + [dst[1]])

# Wrap sendMsgToDst as a Spark UDF; a None return becomes a null message.
sendMsgToDstUdf = F.udf(sendMsgToDst, vertColSchema)

def aggMsgs(agg):
    """Reduce the messages sent to one vertex to the one with the smallest distance.

    ``agg`` is the list produced by ``collect_list`` over the incoming
    messages. Each element is a ``(dist, node, path)`` triple matching
    ``vertColSchema``.

    Returns the ``(dist, node, path)`` of the minimum-distance message, with
    ties going to the first one in ``agg``. Returns ``None`` when ``agg`` is
    empty or missing.
    """
    if not agg:
        return None
    # Select on the distance field (index 0), not the sender's node id
    # (index 1), so the shortest path wins when several messages arrive.
    shortest = min(agg, key=lambda msg: msg[0])
    return (shortest[0], shortest[1], shortest[2])

# Wrap aggMsgs as a Spark UDF returning a single vertColSchema struct per vertex.
aggMsgsUdf = F.udf(aggMsgs, vertColSchema)

# Run Pregel to find, for every vertex, the path from vertex 1 to it.
# The vertex whose id equals 1 starts at distance 0 with path [id]; every other
# vertex starts at distance +inf with the placeholder path [""].
# NOTE(review): the start vertex is hard-coded to id 1, so only descendants of
# vertex 1 get real paths; unreachable vertices keep the placeholder state.
result = (
    g.pregel.withVertexColumn(
        colName = "vertCol",

        # Initial per-vertex state, cast to vertColSchema.
        # NOTE(review): "id" is a string column compared with the int literal 1;
        # this relies on Spark's implicit cast — confirm, or compare with "1".
        initialExpr = F.when(
            F.col("id") == 1,
            F.struct(F.lit(0.0), F.col("id"), F.array(F.col("id")))
        ).otherwise(
            F.struct(F.lit(float("inf")), F.col("id"), F.array(F.lit("")))
        ).cast(vertColSchema),

        # After each superstep, merge the aggregated message into the vertex state.
        updateAfterAggMsgsExpr = vertexProgramUdf(F.col("vertCol"), Pregel.msg())
    )
    # Messages flow only along edge direction (src -> dst).
    .sendMsgToDst(sendMsgToDstUdf(F.col("src.vertCol"), Pregel.dst("vertCol")))
    # Reduce all messages arriving at a vertex to a single one.
    .aggMsgs(aggMsgsUdf(F.collect_list(Pregel.msg())))
    .setMaxIter(3)    ## This should be greater than the max depth of the graph
    # NOTE(review): checkpointing every iteration presumably needs a checkpoint
    # directory configured on the SparkContext — confirm.
    .setCheckpointInterval(1)
    .run()
)

# One row per vertex: its id and the path (list of vertex ids) that reached it.
df = result.select("vertCol.node", "vertCol.path")
df = df.repartition(1)
df.show()
+----+---------+
|node|     path|
+----+---------+
|   1|      [1]|
|   2|   [1, 2]|
|   3|   [1, 3]|
|   4|[1, 3, 4]|
+----+---------+

### Wrangling the dataframe to get desired output

# For every vertex id (`col`), gather itself plus all of its descendants, then
# collect the `value` of each of those vertices.
final = df.select(
    'node',
    # One row per element of each node's path: columns `pos` and `col` (the id).
    F.posexplode_outer('path')
).withColumn(
    'children', 
    # Running collect_list ordered from the deepest position upward: for each path
    # element, the ids from that element to the end of the path.
    F.collect_list('col').over(Window.partitionBy('node').orderBy(F.desc('pos')))
).groupBy('col').agg(
    # Merge those suffixes across every path that passes through `col`,
    # giving `col` and all of its descendants.
    F.array_distinct(F.flatten(F.collect_list('children'))).alias('children')
).alias('t1').repartition(1).join(
    vertices,
    # Attach every vertex whose id is in `col`'s set of descendants (incl. itself).
    F.array_contains(F.col('t1.children'), vertices.id)
).groupBy('col').agg(
    # NOTE(review): Spark does not guarantee collect_list order after a shuffle,
    # so the order inside `values` may vary between runs.
    F.collect_list('value').alias('values')
).withColumnRenamed('col', 'id').orderBy('id')

final.show()
+---+---------------+
| id|         values|
+---+---------------+
|  1|[99, 10, 25, 7]|
|  2|           [10]|
|  3|        [25, 7]|
|  4|            [7]|
+---+---------------+
mck
  • 40,932
  • 13
  • 35
  • 50