5

I built a random forest model using the following code:

import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.IndexToString

// labelIndexer and df are defined earlier (not shown)
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("features")
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val training = labelIndexer.transform(df)
val model = rf.fit(training)

Now I want to save the model so that I can load it later and predict with the following code:

val predictions: DataFrame = model.transform(testData)

I've looked into the Spark documentation here and didn't find any option to do that. Any ideas? It took me a few hours to build the model, and if Spark crashes I won't be able to get it back.

Yaeli778

4 Answers

2

With Spark 1.6 it is possible to save tree-based models to HDFS and reload them using saveAsObjectFile(). This works for both Pipeline-based and basic models. Below is an example for a Pipeline-based model.

import org.apache.spark.ml.PipelineModel

// Fit the pipeline to obtain the model
val model = pipeline.fit(trainingData)

// Wrap the model in a single-partition RDD and write it as an object file
sc.parallelize(Seq(model), 1).saveAsObjectFile("hdfs://filepath")

// Reload the model by specifying its class
// You can get the class of an object using object.getClass()
val sameModel = sc.objectFile[PipelineModel]("hdfs://filepath").first()
2

To save and load a RandomForestClassifier model with spark.ml (tested with Spark 1.6.2 and Scala; in Spark 2.0 the model has a direct save option):

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier

val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
val model = classifier.fit(trainingData)

sc.parallelize(Seq(model), 1).saveAsObjectFile(modelSavePath) // save model

val loadedModel = sc.objectFile[RandomForestClassificationModel](modelSavePath).first() // load model
val predictions1 = loadedModel.transform(testData) // predictions1 is a DataFrame
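
For readers who can move to Spark 2.0 or later, the direct save option mentioned at the top of this answer might look roughly like the sketch below. It assumes Spark 2.0+, where RandomForestClassificationModel provides write and load. The object and method names (ModelPersistenceSketch, trainSaveAndReload) are hypothetical and only there to make the snippet self-contained; the classifier uses the default "label" and "features" columns.

```scala
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.sql.DataFrame

// Hypothetical wrapper object, used only to keep the sketch self-contained.
object ModelPersistenceSketch {

  // Assumes Spark 2.0+; this does not work on Spark 1.6.
  def trainSaveAndReload(trainingData: DataFrame,
                         testData: DataFrame,
                         modelSavePath: String): DataFrame = {
    // Train a random forest on the default "label" / "features" columns
    val model: RandomForestClassificationModel =
      new RandomForestClassifier().setNumTrees(20).fit(trainingData)

    // Persist the model; overwrite() replaces anything already at the path
    model.write.overwrite().save(modelSavePath)

    // Reload the model (this could happen in a different application or session)
    val reloaded = RandomForestClassificationModel.load(modelSavePath)

    // Predict with the reloaded model
    reloaded.transform(testData)
  }
}
```

Unlike saveAsObjectFile(), this route does not go through Java serialization of the model object, so it should be the preferred choice where the Spark version allows it.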
Mentya
1

It is in the MLWriter interface, which is accessed via the write attribute on your model:

model.asInstanceOf[MLWritable].write.save(path)

Here is the interface:

abstract class MLWriter extends BaseReadWrite with Logging {

  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instances to the input path.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {

This is a refactoring from earlier versions of mllib/spark.ml.

Update: It appears that the model is not writable:

Exception in thread "main" java.lang.UnsupportedOperationException: Pipeline write will fail on this Pipeline because it contains a stage which does not implement Writable. Non-Writable stage: rfc_4e467607406f of type class org.apache.spark.ml.classification.RandomForestClassificationModel

So there may not be a straightforward solution for this.

WestCoastProjects
  • this code doesn't work. error: value writer is not a member of org.apache.spark.ml.classification.RandomForestClassificationModel – Yaeli778 May 19 '16 at 13:06
  • @Yaeli778 This feature - as you can see in my answer - `@Since("1.6.0")` - requires a recent version of Spark. If you were on 1.5.X or earlier - then you will not have it. Can you upgrade to 1.6.X ? – WestCoastProjects May 19 '16 at 14:15
  • I run on version 1.6.1 – Yaeli778 May 19 '16 at 14:20
  • @Yaeli778 It seems to be necessary to add this: .asInstanceOf[MLWritable] . OP has been updated to reflect it. – WestCoastProjects May 19 '16 at 15:20
  • It doesn't work either. I get the following error: java.lang.ClassCastException: org.apache.spark.ml.classification.RandomForestClassificationModel cannot be cast to org.apache.spark.ml.util.MLWritable – Yaeli778 May 19 '16 at 16:41
  • @Yaeli778 I just realized the post had a typo - it did not precisely match my code. It is `.write.` instead of `.writer.` – WestCoastProjects May 19 '16 at 16:44
  • You can also see in the documentation here [https://spark.apache.org/docs/1.6.1/api/java/org/apache/spark/ml/util/MLWritable.html] that it doesn't support Random Forest. – Yaeli778 May 19 '16 at 16:47
  • I already ran the code with the `write` command, but I get the following error: `model.asInstanceOf[MLWritable].write.save("/tmp/model") java.lang.ClassCastException: org.apache.spark.ml.classification.RandomForestClassificationModel cannot be cast to org.apache.spark.ml.util.MLWritable` – Yaeli778 May 19 '16 at 16:49
  • @Yaeli778 I ran that code: so it does compile and does get submitted to spark successfully. Well it's academic anyways given the lack of support. – WestCoastProjects May 19 '16 at 16:54
  • Thanks for trying. I just don't get it, there is no way to save the model and use it later? It doesn't make any sense – Yaeli778 May 19 '16 at 16:58
  • @Yaeli778 Likewise surprised at the lack of support. Sorry could not be of more help. – WestCoastProjects May 19 '16 at 17:01
1

Here is a PySpark v1.6 implementation corresponding to the Scala saveAsObjectFile() answer above.

It coerces the Python objects to and from Java objects to achieve serialisation with saveAsObjectFile().

Without the Java coercion I got weird Py4J errors on serialisation. If anyone has a simpler implementation, please edit or comment.

Save a trained RandomForestClassificationModel object:

# Save RandomForestClassificationModel to hdfs
gateway = sc._gateway
java_list = gateway.jvm.java.util.ArrayList()
java_list.add(rfModel._java_obj)
modelRdd = sc._jsc.parallelize(java_list)
modelRdd.saveAsObjectFile("hdfs:///some/path/rfModel")

Load a trained RandomForestClassificationModel object:

# Load RandomForestClassificationModel from hdfs
from pyspark.ml.classification import RandomForestClassificationModel

rfObjectFileLoaded = sc._jsc.objectFile("hdfs:///some/path/rfModel")
rfModelLoaded_JavaObject = rfObjectFileLoaded.first()
# Wrap the deserialised Java object in the Python model class
rfModelLoaded = RandomForestClassificationModel(rfModelLoaded_JavaObject)
predictions = rfModelLoaded.transform(test_input_df)
Dylan Hogg