One of the issues I've run into with Apache Spark is visualizing decision trees.
I can produce a tree using DecisionTree.trainClassifier, and I can get some rudimentary output using:
print(model.toDebugString())
The current output looks like this:
If (feature 0 <= -35.0)
  If (feature 24 <= 176.0)
    Predict: 2.1
  If (feature 24 = 176.0)
    Predict: 4.2
  Else (feature 24 > 176.0)
    Predict: 6.3
Else (feature 0 > -35.0)
  If (feature 24 <= 11.0)
    Predict: 4.5
  Else (feature 24 > 11.0)
    Predict: 10.2
Ideally, it could be output as JSON, or in some other parseable format, so that we could layer a D3 visualization library on top of it. Using the example above, the output might look like this:
{
  "node": [
    {
      "name": "node1",
      "children": [
        {
          "name": "node2",
          "rule": "feature 0 <= -35.0",
          "children": [
            {
              "name": "node4",
              "rule": "feature 24 <= 176.0",
              "predict": 2.1
            },
            {
              "name": "node5",
              "rule": "feature 24 = 176.0",
              "predict": 4.2
            },
            {
              "name": "node6",
              "rule": "feature 24 > 176.0",
              "predict": 6.3
            }
          ]
        },
        {
          "name": "node3",
          "rule": "feature 0 > -35.0",
          "children": [
            {
              "name": "node7",
              "rule": "feature 24 <= 11.0",
              "predict": 4.5
            },
            {
              "name": "node8",
              "rule": "feature 24 > 11.0",
              "predict": 10.2
            }
          ]
        }
      ]
    }
  ]
}
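As a stopgap, the text returned by toDebugString() can be turned into this kind of structure with a small parser. The sketch below is a minimal, hypothetical example: `parse_debug_string` and `to_d3_json` are names made up here, not Spark APIs. It assumes each nested branch is indented further than its parent (as in the real toDebugString() output), treats every `If (...)`/`Else (...)` line as a node whose `rule` is the text inside the parentheses, attaches each `Predict:` value to the branch directly above it, and numbers nodes in level order so the names mirror the example. Any other line, such as a summary header, is skipped.

```python
import json
import re
from collections import deque

# A branch line such as "If (feature 0 <= -35.0)" or "Else (feature 3 in {1.0})".
COND = re.compile(r"^(If|Else) \((.*)\)$")
# A leaf line such as "Predict: 2.1" (anything after the first token is ignored).
PRED = re.compile(r"^Predict: (\S+)")


def parse_debug_string(text):
    """Turn indented toDebugString()-style text into nested dicts (hypothetical helper)."""
    root = {"name": None}
    stack = [(-1, root)]  # (indentation, node); the root sits above every real line
    for raw in text.splitlines():
        line = raw.strip()
        cond = COND.match(line)
        pred = PRED.match(line)
        if not (cond or pred):
            continue  # blank lines, summary headers, etc.
        indent = len(raw) - len(raw.lstrip())
        # Climb back up until the top of the stack is this line's parent.
        while stack[-1][0] >= indent:
            stack.pop()
        parent = stack[-1][1]
        if cond:
            node = {"name": None, "rule": cond.group(2)}
            parent.setdefault("children", []).append(node)
            stack.append((indent, node))
        else:
            parent["predict"] = float(pred.group(1))
    # Name nodes in level order (node1 = root) to match the hand-written JSON.
    queue, counter = deque([root]), 0
    while queue:
        node = queue.popleft()
        counter += 1
        node["name"] = f"node{counter}"
        queue.extend(node.get("children", []))
    return root


def to_d3_json(tree):
    """Wrap the tree in the {"node": [...]} shape used in the question."""
    return json.dumps({"node": [tree]}, indent=2)


# The debug output from the question, with its nesting restored as indentation.
SAMPLE = """\
If (feature 0 <= -35.0)
  If (feature 24 <= 176.0)
    Predict: 2.1
  If (feature 24 = 176.0)
    Predict: 4.2
  Else (feature 24 > 176.0)
    Predict: 6.3
Else (feature 0 > -35.0)
  If (feature 24 <= 11.0)
    Predict: 4.5
  Else (feature 24 > 11.0)
    Predict: 10.2
"""

if __name__ == "__main__":
    # With a real model this would be: parse_debug_string(model.toDebugString())
    print(to_d3_json(parse_debug_string(SAMPLE)))
```

Because the parser only relies on relative indentation, the exact number of spaces per level does not matter. Since d3.hierarchy reads a `children` array by default, the inner tree dict should be usable with it directly; the `{"node": [...]}` wrapper only mirrors the example above.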