I've created the following graph:
from pyspark.sql import SparkSession
from graphframes import GraphFrame

spark = SparkSession.builder.appName('aggregate').getOrCreate()
vertices = spark.createDataFrame([('1', 'foo', 99),
                                  ('2', 'bar', 10),
                                  ('3', 'baz', 25),
                                  ('4', 'spam', 7)],
                                 ['id', 'name', 'value'])
edges = spark.createDataFrame([('1', '2'),
                               ('1', '3'),
                               ('3', '4')],
                              ['src', 'dst'])
g = GraphFrame(vertices, edges)
I would like to aggregate the messages so that, for any given vertex, we have a list of the values of all its descendant vertices, all the way down to the leaves. For example, vertex 1 has a child edge to vertex 3, which in turn has a child edge to vertex 4. Vertex 1 also has a child edge to vertex 2. That is:
(1) --> (3) --> (4)
  \
   \--> (2)
From 1 I'd like to collect all values along these paths: [99, 10, 25, 7], where 99 is the value of vertex 1, 10 is the value of the child vertex 2, 25 is the value at vertex 3, and 7 is the value at vertex 4.
From 3 we'd have the values [25, 7], etc.
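To pin down the result I'm after, here is a plain-Python sketch (no Spark involved) that computes it for the small graph above. The helper name collect_descendant_values is made up for illustration, and the sketch assumes the graph has no cycles:

```python
# Same data as the graph above, as plain Python structures.
values = {'1': 99, '2': 10, '3': 25, '4': 7}
edges = [('1', '2'), ('1', '3'), ('3', '4')]

# Adjacency list: parent id -> list of child ids, in edge order.
children = {}
for src, dst in edges:
    children.setdefault(src, []).append(dst)


def collect_descendant_values(vertex):
    """Return the vertex's own value followed by the values of all
    its descendants (pre-order). Hypothetical helper; assumes no cycles."""
    result = [values[vertex]]
    for child in children.get(vertex, []):
        result.extend(collect_descendant_values(child))
    return result


# Desired allValues for every vertex.
expected = {v: collect_descendant_values(v) for v in values}
```

Here expected['1'] is [99, 10, 25, 7] and expected['3'] is [25, 7]. The order of the values is not important to me, only that every descendant is included.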
I can approximate this with aggregateMessages:
from pyspark.sql.functions import collect_list
from graphframes.lib import AggregateMessages as AM

agg = g.aggregateMessages(collect_list(AM.msg).alias('allValues'),
                          sendToSrc=AM.dst['value'],
                          sendToDst=None)
agg.show()
Which produces:
+---+---------+
| id|allValues|
+---+---------+
|  3|      [7]|
|  1| [25, 10]|
+---+---------+
At 1 we have [25, 10], which are the immediate child values, but we are missing 7 and the "self" value of 99.
Similarly, I'm missing 25 for vertex 3.
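As far as I understand it, a single aggregateMessages call lets each edge deliver one message, so a vertex only ever hears from its direct children. The following plain-Python imitation of one such round is my own approximation, not the GraphFrames implementation, and the function name one_hop_child_values is made up. It shows the same gap:

```python
# Same data as the graph above, as plain Python structures.
values = {'1': 99, '2': 10, '3': 25, '4': 7}
edges = [('1', '2'), ('1', '3'), ('3', '4')]


def one_hop_child_values(values, edges):
    """Imitate one message-passing round: every edge sends the
    destination's value to its source (like sendToSrc=AM.dst['value']).
    Hypothetical helper for illustration only."""
    inbox = {}
    for src, dst in edges:
        inbox.setdefault(src, []).append(values[dst])
    return inbox


one_hop = one_hop_child_values(values, edges)
```

Vertex 1 receives only 10 and 25, and vertex 3 receives only 7. Vertices 2 and 4 receive nothing, so they do not appear at all, just as in the output of agg.show(). The order inside each list may differ from what Spark prints.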
How can I aggregate messages "recursively", such that allValues from child vertices are aggregated at the parent?