I'm training an XGBoost model with hyperparameter tuning. My code snippet looks something like this:

val paramGrid = new ParamGridBuilder().
    addGrid(booster.minChildWeight, Array(0.3, 0.6, 0.7, 0.8)).
    addGrid(booster.eta, Array(0.1, 0.2, 0.4, 0.6)).
    build()


val cv = new CrossValidator().
    setEstimator(pipeline).
    setEvaluator(evaluator).
    setEstimatorParamMaps(paramGrid).
    setNumFolds(10)

val cvModel = cv.fit(df)

val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(1).
    asInstanceOf[XGBoostClassificationModel]

Now I want to save the parameter map to a .txt file and parse it later. However, when I try to write it to a text file with something like this:

bestModel.extractParamMap()

val file = new File("/home/hadoop/test/hyper_params.txt")
val bw = new BufferedWriter(new FileWriter(file))
bw.write(bestModel.extractParamMap())
bw.close()

I'm getting the following error:

error: overloaded method value write with alternatives:
  (x$1: Int)Unit <and>
  (x$1: String)Unit <and>
  (x$1: Array[Char])Unit
 cannot be applied to (org.apache.spark.ml.param.ParamMap)
       bw.write(bestModel.extractParamMap())

I'm pretty new to Scala and haven't been able to find a way to save the parameter map to a .txt file. This is the first step of my problem.

Next, I want to create some variables whose values are read from the parameters saved in the .txt file.

Say something like this:

val min_child_weight=('../param.txt){key value here}

So how can I do this? I've gone through some posts like this and this, but haven't been able to work out the code for my purpose.

Debadri Dutta

1 Answer

First, you normally don't write Spark data to your local file system with a plain BufferedWriter. For DataFrames and RDDs, you would typically use the Spark API and prefix the path with "file:///", as shown here - How to save Spark RDD to local filesystem. The compile error itself occurs because BufferedWriter.write accepts an Int, a String, or an Array[Char], not a ParamMap, so the map has to be converted with toString first. For what you are doing, you would also normally use the MLWriter and save the whole Pipeline, like so - https://jaceklaskowski.gitbooks.io/mastering-apache-spark/spark-mllib/spark-mllib-pipelines-persistence.html.

UPDATED:

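// extractParamMap() returns a ParamMap, so convert it to a String before writing.
// saveAsTextFile creates a directory at this path that contains part-* files.
spark
  .sparkContext
  .parallelize(List(bestModel.extractParamMap().toString))
  .saveAsTextFile("file:///home/hadoop/test/hyper_params.txt")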
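
For the second part of the question (reading the values back into variables), here is a minimal sketch, assuming a single small file on the driver machine is acceptable and that one "name=value" line per parameter is a good enough format. The ParamFile object and its save and load methods are hypothetical helpers written for illustration, not part of Spark or XGBoost. It uses only the Scala and Java standard libraries, so the Spark-specific step of turning the ParamMap into name/value pairs appears only as a comment.

```scala
import java.io.PrintWriter
import scala.io.Source

// Hypothetical helper (not a Spark or XGBoost API): stores parameters as
// one "name=value" line each and reads them back into a Map.
object ParamFile {

  // With Spark on the classpath, the pairs could be built from the model, e.g.
  //   bestModel.extractParamMap().toSeq.map(p => p.param.name -> s"${p.value}")
  def save(path: String, params: Seq[(String, String)]): Unit = {
    val out = new PrintWriter(path)
    try params.foreach { case (name, value) => out.println(s"$name=$value") }
    finally out.close()
  }

  // Returns name -> value; values stay strings, so convert them at the call site.
  def load(path: String): Map[String, String] = {
    val src = Source.fromFile(path)
    try {
      src.getLines()
        .map(_.trim)
        .filter(_.contains("="))            // skip blank or malformed lines
        .map { line =>
          val Array(name, value) = line.split("=", 2) // split on the first '=' only
          name.trim -> value.trim
        }
        .toMap
    } finally src.close()
  }
}
```

A value could then be read with something like val minChildWeight = ParamFile.load("/home/hadoop/test/hyper_params.txt")("minChildWeight").toDouble, assuming the parameter was saved under the name minChildWeight. Array-valued parameters would not survive this round trip as written, since their string form is not parseable.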
uh_big_mike_boi