
I am new to Python and stuck with building a hierarchy out of a relational dataset.
It would be of immense help if someone has an idea on how to proceed with this.

I have a relational data-set with data like

_currentnode,  childnode_  
 root,         child1  
 child1,       leaf2  
 child1,       child3  
 child1,       leaf4  
 child3,       leaf5  
 child3,       leaf6  

and so on. I am looking for Python or PySpark code to
build a hierarchy dataframe like the one below

_level1, level2,  level3,  level4_  
root,    child1,  leaf2,   null  
root,    child1,  child3,  leaf5  
root,    child1,  child3,  leaf6  
root,    child1,  leaf4,   null  

The data is alphanumeric and the dataset is huge (~50 million records).
Also, the root of the hierarchy is known and can be hard-coded.
So in the example above, the root of the hierarchy is 'root'.

mck
Vardhan

1 Answer


Shortest Path with Pyspark

The input data can be interpreted as a graph whose edges connect currentnode to childnode. The question then becomes: what is the shortest path from the root node to every leaf node? This problem is known as single-source shortest path (SSSP).
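To make the idea concrete, here is a minimal pure-Python sketch of single-source shortest path on the sample data from the question, using a breadth-first search that also records the path to each node. It is only an illustration for small inputs, not the distributed approach described below; the function name `shortest_paths` is made up for this example.

```python
from collections import deque

# Sample edges from the question: (currentnode, childnode)
edges = [
    ("root", "child1"),
    ("child1", "leaf2"),
    ("child1", "child3"),
    ("child1", "leaf4"),
    ("child3", "leaf5"),
    ("child3", "leaf6"),
]

def shortest_paths(edges, root):
    """Breadth-first search: shortest path from root to every reachable node."""
    # Adjacency list: parent -> list of children
    children = {}
    for parent, child in edges:
        children.setdefault(parent, []).append(child)

    paths = {root: [root]}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in children.get(node, []):
            # With unit edge weights, the first visit is along a shortest path
            if child not in paths:
                paths[child] = paths[node] + [child]
                queue.append(child)
    return paths

paths = shortest_paths(edges, "root")
print(paths["leaf5"])
```

Each entry of `paths` maps a node to the list of nodes on the way from the root, which is the same information the Pregel-based solution collects in its path column.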

Spark has GraphX to handle parallel graph computations. Unfortunately, GraphX does not provide a Python API (more details can be found here). A graph library with Python support is GraphFrames, which uses parts of GraphX.

Both GraphX and GraphFrames provide a solution for SSSP. Unfortunately, both implementations return only the lengths of the shortest paths, not the paths themselves (GraphX and GraphFrames). But this answer provides an implementation of the algorithm for GraphX and Scala that also returns the paths. All three solutions use Pregel.

Translating the aforementioned answer to GraphFrames/Python:

1. Data preparation

Provide unique IDs for all nodes and rename the columns so that they match the names described here

import pyspark.sql.functions as F

df = ...

vertices = df.select("currentnode").withColumnRenamed("currentnode", "node") \
    .union(df.select("childnode")).distinct() \
    .withColumn("id", F.monotonically_increasing_id()).cache()

edges = df.join(vertices, df.currentnode == vertices.node).drop(F.col("node")).withColumnRenamed("id", "src") \
    .join(vertices, df.childnode == vertices.node).drop(F.col("node")).withColumnRenamed("id", "dst").cache()

Nodes                   Edges
+------+------------+   +-----------+---------+------------+------------+
|  node|          id|   |currentnode|childnode|         src|         dst|
+------+------------+   +-----------+---------+------------+------------+
| leaf2| 17179869184|   |     child1|    leaf4| 25769803776|249108103168|
|child1| 25769803776|   |     child1|   child3| 25769803776| 68719476736|
|child3| 68719476736|   |     child1|    leaf2| 25769803776| 17179869184|
| leaf6|103079215104|   |     child3|    leaf6| 68719476736|103079215104|
|  root|171798691840|   |     child3|    leaf5| 68719476736|214748364800|
| leaf5|214748364800|   |       root|   child1|171798691840| 25769803776|
| leaf4|249108103168|   +-----------+---------+------------+------------+
+------+------------+   

2. Create the GraphFrame

from graphframes import GraphFrame
graph = GraphFrame(vertices, edges)

3. Create UDFs that will form the single parts of the Pregel algorithm

The message type:

from pyspark.sql.types import ArrayType, DoubleType, StringType, StructType
vertColSchema = StructType()\
      .add("dist", DoubleType())\
      .add("node", StringType())\
      .add("path", ArrayType(StringType(), True))

The vertex program:

def vertexProgram(vd, msg):
    # Keep the current state unless the message offers a shorter distance
    if msg is None or vd[0] < msg[0]:
        return (vd[0], vd[1], vd[2])
    else:
        return (msg[0], vd[1], msg[2])
vertexProgramUdf = F.udf(vertexProgram, vertColSchema)

The outgoing messages:

def sendMsgToDst(src, dst):
    srcDist = src[0]
    dstDist = dst[0]
    # Send a message only if it would shorten the destination's distance
    if srcDist < (dstDist - 1):
        return (srcDist + 1, src[1], src[2] + [dst[1]])
    else:
        return None
sendMsgToDstUdf = F.udf(sendMsgToDst, vertColSchema)

Message aggregation:

def aggMsgs(agg):
    # Pick the message with the smallest distance (field 0 of the struct)
    shortest_dist = min(agg, key=lambda tup: tup[0])
    return (shortest_dist[0], shortest_dist[1], shortest_dist[2])
aggMsgsUdf = F.udf(aggMsgs, vertColSchema)

4. Combine the parts

from graphframes.lib import Pregel
result = graph.pregel.withVertexColumn(colName = "vertCol", \
    initialExpr = F.when(F.col("node")==(F.lit("root")), F.struct(F.lit(0.0), F.col("node"), F.array(F.col("node")))) \
    .otherwise(F.struct(F.lit(float("inf")), F.col("node"), F.array(F.lit("")))).cast(vertColSchema), \
    updateAfterAggMsgsExpr = vertexProgramUdf(F.col("vertCol"), Pregel.msg())) \
    .sendMsgToDst(sendMsgToDstUdf(F.col("src.vertCol"), Pregel.dst("vertCol"))) \
    .aggMsgs(aggMsgsUdf(F.collect_list(Pregel.msg()))) \
    .setMaxIter(10) \
    .setCheckpointInterval(2) \
    .run()
result.select("vertCol.path").show(truncate=False)   

Remarks:

  • maxIter should be set to a value at least as large as the length of the longest path. A higher value leaves the result unchanged but increases the computation time. If the value is too small, the longer paths will be missing from the result. The current version of GraphFrames (0.8.0) does not support stopping the loop when no more new messages are sent.
  • checkpointInterval should be set to a value smaller than maxIter. The actual value depends on the data and the available hardware. If OutOfMemory exceptions occur or the Spark session hangs for some time, the value could be reduced.
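The effect of maxIter can be illustrated with a simplified pure-Python model of the message-passing loop. This is a sketch under the assumption that every iteration sends messages along all edges based on the state from the previous iteration; it does not use GraphFrames, and the name `run_supersteps` is hypothetical.

```python
import math

# Sample edges from the question: (currentnode, childnode)
edges = [
    ("root", "child1"),
    ("child1", "leaf2"),
    ("child1", "child3"),
    ("child1", "leaf4"),
    ("child3", "leaf5"),
    ("child3", "leaf6"),
]

def run_supersteps(edges, root, max_iter):
    """Simplified model of the iterative loop with a fixed iteration count."""
    nodes = {n for edge in edges for n in edge}
    # State per node: (distance, path); the root starts at 0, the rest at infinity
    state = {n: (0.0, [n]) if n == root else (math.inf, []) for n in nodes}
    for _ in range(max_iter):
        # Send phase: an edge emits a message only if it shortens the destination
        inbox = {}
        for src, dst in edges:
            src_dist, src_path = state[src]
            if src_dist + 1 < state[dst][0]:
                inbox.setdefault(dst, []).append((src_dist + 1, src_path + [dst]))
        # Aggregate and update phase: keep the message with the smallest distance
        for dst, msgs in inbox.items():
            state[dst] = min(msgs, key=lambda m: m[0])
    # Only nodes that were reached have a finite distance
    return {n: path for n, (dist, path) in state.items() if dist < math.inf}

too_few = run_supersteps(edges, "root", 2)
enough = run_supersteps(edges, "root", 3)
more = run_supersteps(edges, "root", 10)
print(len(too_few), len(enough), len(more))
```

In this model, two iterations reach only the nodes up to two edges away from the root, so the paths to leaf5 and leaf6 are missing; three iterations find all paths, and further iterations change nothing.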

The final result is a regular dataframe with the content

+-----------------------------+
|path                         |
+-----------------------------+
|[root, child1]               |
|[root, child1, leaf4]        |
|[root, child1, child3]       |
|[root]                       |
|[root, child1, child3, leaf6]|
|[root, child1, child3, leaf5]|
|[root, child1, leaf2]        |
+-----------------------------+

If necessary the non-leaf nodes could be filtered out here.
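As an illustration of that filtering step, the following pure-Python sketch takes the paths from the result table above, keeps only those ending in a leaf, and pads them with None so that each row has the level1 to level4 layout asked for in the question. It works on a small in-memory list only; the helper name `to_level_rows` is made up for this example, and a PySpark version would have to express the same logic with DataFrame operations.

```python
# Paths as shown in the result table above, one per node
paths = [
    ["root", "child1"],
    ["root", "child1", "leaf4"],
    ["root", "child1", "child3"],
    ["root"],
    ["root", "child1", "child3", "leaf6"],
    ["root", "child1", "child3", "leaf5"],
    ["root", "child1", "leaf2"],
]

def to_level_rows(paths):
    """Keep only paths that end in a leaf and pad them to equal length."""
    # A node that appears before the last position of any path has children
    non_leaves = {node for path in paths for node in path[:-1]}
    leaf_paths = [p for p in paths if p[-1] not in non_leaves]
    # Pad shorter paths with None so that every row has the same number of levels
    width = max(len(p) for p in leaf_paths)
    return [p + [None] * (width - len(p)) for p in leaf_paths]

rows = to_level_rows(paths)
for row in rows:
    print(row)
```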

Marioanzas
werner
  • Very concise and clean implementation. Just one question: how does the checkpoint interval work in this algorithm? – Amardeep Kohli Feb 08 '21 at 20:21
  • The checkpoint parameter controls how often a checkpoint will be created. The effect of a checkpoint here is that it reduces the required amount of memory but increases IO. – werner Feb 10 '21 at 20:29
  • @werner I have used this as well to get a hierarchy, but instead of the path I would like to get the list of children for each node. How should I modify the message functions? Any suggestions? – Matioski Jan 12 '22 at 10:37
  • @Matioski what exactly would be your output? Do I understand you correctly that you are looking for the next elements of all paths starting from a certain node (e.g. root -> [child1], child1 -> [leaf4, child3, leaf2])? – werner Jan 13 '22 at 16:59
  • @werner yes, I have a connected directed tree and for each node we want to get the list of the reachable nodes (e.g. child1[leaf4,child3,leaf2,leaf6,leaf5], child3[leaf6,leaf5]). Currently, what I have done after running the above Pregel is the following: – Matioski Jan 27 '22 at 14:20
  • `final = df.select( 'node', F.posexplode_outer('path') ).withColumn( 'children', F.collect_list('col').over(Window.partitionBy('node').orderBy(F.desc('pos'))) ).groupBy('col').agg( F.array_distinct(F.flatten(F.collect_list('children'))).alias('children') ).alias('t1').repartition(1).join( nodes_ident, F.array_contains(F.col('t1.children'), nodes_ident.name) ).groupBy('col').agg( F.collect_list('name').alias('values') ).withColumnRenamed('col', 'id').orderBy('id')` – Matioski Jan 27 '22 at 14:20
  • Does this solution work? I am trying to implement something very similar to @Vardhan's question, but when I implement this solution, I am not getting a full path. In the results, I get only the root element. – Laur Jan 08 '23 at 18:25
  • Curious if we have a Python API for GraphX yet; I couldn't find related information online. – user3437212 Jul 26 '23 at 16:58
  • @user3437212 I am not aware of a GraphX Python API. According to [this discussion](https://issues.apache.org/jira/browse/SPARK-3789) it will probably not be implemented. – werner Jul 28 '23 at 13:10