
I'd like to get the weights of the tree nodes from a saved (or unsaved) DecisionTreeClassificationModel. However, I can't find anything remotely resembling that.

How does the model actually perform the classification without knowing any of those? Below are the Params that are saved in the model:

{"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel",
 "timestamp":1551207582648,
 "sparkVersion":"2.3.2",
 "uid":"DecisionTreeClassifier_4ffc94d20f1ddb29f282",
 "paramMap":{
   "cacheNodeIds":false,
   "maxBins":32,
   "minInstancesPerNode":1,
   "predictionCol":"prediction",
   "minInfoGain":0.0,
   "rawPredictionCol":"rawPrediction",
   "featuresCol":"features",
   "probabilityCol":"probability",
   "checkpointInterval":10,
   "seed":956191873026065186,
   "impurity":"gini",
   "maxMemoryInMB":256,
   "maxDepth":2,
   "labelCol":"indexed"
 },
 "numFeatures":1,
 "numClasses":2
}
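To inspect this metadata directly, here is a minimal sketch in plain Python. It assumes the usual Spark ML writer layout, where the metadata is a one-line JSON text file under `<path>/metadata` whose name starts with `part-` (for example `part-00000`); `load_model_metadata` is a hypothetical helper, not a Spark API. It makes the point of the question concrete: this file holds only hyperparameters and shape information (`numFeatures`, `numClasses`), not the tree itself.

```python
import json
import os


def load_model_metadata(model_path: str) -> dict:
    """Parse the JSON metadata a Spark ML writer saves under <model_path>/metadata.

    Assumption: the metadata directory contains a text file whose name starts
    with "part-" (e.g. part-00000) holding a single JSON line. Checksum files
    (".part-00000.crc") and the "_SUCCESS" marker are skipped by the prefix test.
    """
    meta_dir = os.path.join(model_path, "metadata")
    for name in sorted(os.listdir(meta_dir)):
        if name.startswith("part-"):
            with open(os.path.join(meta_dir, name), encoding="utf-8") as f:
                return json.loads(f.readline())
    raise FileNotFoundError(f"no part-* file found in {meta_dir}")
```

For a locally saved model, `load_model_metadata(path)["paramMap"]` should return the same dictionary as shown above.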
10465355
Jeff Saremi

1 Answer


By using treeWeights:

treeWeights

Return the weights for each tree

New in version 1.5.0.

As for "How does the model actually perform the classification without knowing any of those?":

The weights are stored, just not as part of the metadata. If you have a model

from pyspark.ml.classification import RandomForestClassificationModel

model: RandomForestClassificationModel = ...

and save it to disk

path: str = ...

model.save(path)

you'll see that the writer creates a treesMetadata subdirectory. If you load its content (the default writer uses Parquet):

import os

trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))

you'll see the following structure:

trees_metadata.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)

where the weights column contains the weight of the tree identified by treeID.
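To illustrate what such per-tree weights are for, here is a minimal sketch of weighted soft voting in plain Python. It assumes each tree contributes the normalized class counts (`impurityStats`) of the leaf it reaches, scaled by its weight. As far as I know this matches how Spark's random forest classifier aggregates trees, except that all of its saved weights are 1.0. GBT models combine weighted regression outputs instead, which this sketch does not cover. `weighted_soft_vote` is a hypothetical helper, not a Spark API.

```python
from typing import List, Sequence


def weighted_soft_vote(
    leaf_stats: Sequence[Sequence[float]], tree_weights: Sequence[float]
) -> List[float]:
    """Combine per-tree leaf class counts into a single class distribution.

    leaf_stats[t] is the impurityStats vector of the leaf that tree t reached;
    tree_weights[t] is that tree's weight (the "weights" column above).
    """
    if not leaf_stats:
        raise ValueError("need at least one tree")
    if len(leaf_stats) != len(tree_weights):
        raise ValueError("need exactly one weight per tree")
    votes = [0.0] * len(leaf_stats[0])
    for stats, weight in zip(leaf_stats, tree_weights):
        total = sum(stats)
        if total == 0:
            continue  # an empty leaf contributes nothing
        for i, count in enumerate(stats):
            # Each tree votes with its normalized class distribution, scaled by its weight.
            votes[i] += weight * count / total
    norm = sum(votes)
    return [v / norm for v in votes] if norm > 0 else votes
```

With equal weights, `weighted_soft_vote([[3.0, 1.0], [1.0, 1.0]], [1.0, 1.0])` averages the two leaf distributions `[0.75, 0.25]` and `[0.5, 0.5]`.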

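Similarly, node data is stored in the data subdirectory (see for example Extract and Visualize Model Trees from Sparklyr). For a single decision tree it has the schema below; for ensembles, as far as I recall, each row instead carries a treeID with the same fields nested in a nodeData struct: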

spark.read.parquet(os.path.join(path, "data")).printSchema()
root
 |-- id: integer (nullable = true)
 |-- prediction: double (nullable = true)
 |-- impurity: double (nullable = true)
 |-- impurityStats: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- gain: double (nullable = true)
 |-- leftChild: integer (nullable = true)
 |-- rightChild: integer (nullable = true)
 |-- split: struct (nullable = true)
 |    |-- featureIndex: integer (nullable = true)
 |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- numCategories: integer (nullable = true)
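This node table is also what answers "how does the model classify": prediction is a walk from the root to a leaf. Here is a minimal plain-Python sketch of that walk. It assumes the rows of one tree were collected as dicts (for example `[r.asDict(recursive=True) for r in spark.read.parquet(os.path.join(path, "data")).collect()]`), that the root has id 0, that leaves have leftChild == -1, and that continuous splits (numCategories == -1) send a value left when it is <= the threshold. `predict_from_nodes` is a hypothetical helper; for ensemble data you would first filter by treeID and unwrap nodeData.

```python
from typing import Dict, List, Sequence


def predict_from_nodes(rows: List[dict], features: Sequence[float]) -> float:
    """Walk a tree stored as rows of the "data" table and return the leaf prediction.

    Assumptions: rows are dicts with the columns shown above, the root has
    id 0, and leaf nodes are marked by leftChild == -1.
    """
    nodes: Dict[int, dict] = {row["id"]: row for row in rows}
    node = nodes[0]
    while node["leftChild"] != -1:
        split = node["split"]
        value = features[split["featureIndex"]]
        if split["numCategories"] == -1:
            # Continuous split: leftCategoriesOrThreshold holds a single threshold.
            go_left = value <= split["leftCategoriesOrThreshold"][0]
        else:
            # Categorical split: leftCategoriesOrThreshold lists the categories sent left.
            go_left = value in split["leftCategoriesOrThreshold"]
        node = nodes[node["leftChild"] if go_left else node["rightChild"]]
    return node["prediction"]
```

Note that only `split`, the child ids and the leaf `prediction` are needed to classify; `impurity`, `impurityStats` and `gain` describe how the tree was built and what the leaves contain.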

Equivalent node information is saved for DecisionTreeClassificationModel as well, in its data subdirectory; there is just no treesMetadata (per-tree metadata and weights), since there is only a single tree.
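If "weights of the tree nodes" means what each leaf contributes, the relevant column is impurityStats. As far as I can tell, for a classification tree with gini or entropy impurity it holds the per-class (weighted) instance counts at that node; the model's rawPrediction is that vector and its probability is the normalized version. A minimal sketch under that assumption, with `leaf_output` as a hypothetical helper:

```python
from typing import List, Sequence, Tuple


def leaf_output(impurity_stats: Sequence[float]) -> Tuple[int, List[float]]:
    """Turn a leaf's impurityStats (assumed per-class counts for gini/entropy)
    into (predicted class index, class probabilities)."""
    total = sum(impurity_stats)
    if total == 0:
        raise ValueError("leaf has no instance counts")
    probabilities = [count / total for count in impurity_stats]
    # The predicted class is the one with the largest count (first one wins on ties).
    predicted = max(range(len(impurity_stats)), key=lambda i: impurity_stats[i])
    return predicted, probabilities
```

For example, a leaf with impurityStats `[3.0, 1.0]` would predict class 0 with probabilities `[0.75, 0.25]`.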

10465355
  • Thanks a lot. The link to the documentation is for RandomForests and not for the basic DecisionTree. However, by saving the model and reading back the data from the Parquet files, I should be able to get the same information for the DecisionTree – Jeff Saremi Feb 27 '19 at 17:08
  • Indeed, I focused on tree weights. In such case I don't think this really answers the question - could you unaccept it, so it can be deleted? TIA – 10465355 Feb 27 '19 at 19:37
  • Doesn't matter. Your approach was what I was looking for anyway – Jeff Saremi Feb 27 '19 at 19:47