
I've tried to use a Random Forest model to predict a stream of examples, but it appears that I cannot use that model to classify the examples. Here is the PySpark code:

sc = SparkContext(appName="App")

model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, impurity='gini', numTrees=150)


ssc = StreamingContext(sc, 1)
lines = ssc.socketTextStream(hostname, int(port))

parsedLines = lines.map(parse)
parsedLines.pprint()

predictions = parsedLines.map(lambda event: model.predict(event.features))

and the error returned when running it on the cluster:

    Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

Is there a way to use a model generated from static data to predict streaming examples?

Thanks guys, I really appreciate it!

zero323
testing

1 Answer


Yes, you can use a model generated from static data. The problem you experience is not related to streaming at all. You simply cannot use a JVM-based model inside an action or a transformation (see How to use Java/Scala function from an action or a transformation? for an explanation why). Instead, you should apply the predict method to a complete RDD, for example using transform on the DStream:

from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from operator import attrgetter


sc = SparkContext("local[2]", "foo")
ssc = StreamingContext(sc, 1)

data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
trainingData, testData = data.randomSplit([0.7, 0.3])

model = RandomForest.trainClassifier(
    trainingData, numClasses=2, numTrees=3
)

(ssc
    .queueStream([testData])
    # Extract features
    .map(attrgetter("features"))
    # Predict 
    .transform(lambda _, rdd: model.predict(rdd))
    .pprint())

ssc.start()
ssc.awaitTerminationOrTimeout(10)
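If you also want the true labels next to the predictions (for example, to measure accuracy on the stream), the same whole-RDD pattern applies: extract the features RDD, call predict on it as a whole, and zip the result back with the labels. The helper below is a sketch, not part of the original answer; `predict_with_labels` is a hypothetical name, and it assumes an MLlib model and an RDD of LabeledPoint as above:

```python
from operator import attrgetter


def predict_with_labels(model, rdd):
    """Pair each point's label with the model's prediction for it.

    Assumes `model` is an MLlib model (e.g. RandomForestModel) and
    `rdd` is an RDD of LabeledPoint. The key point is that
    model.predict is applied to the whole features RDD on the driver
    side, never inside a worker-side closure, which is what triggers
    the SPARK-5063 error.
    """
    features = rdd.map(attrgetter("features"))
    # zip works here because predict returns an RDD with the same
    # partitioning and element count as its input.
    return rdd.map(attrgetter("label")).zip(model.predict(features))
```

Inside the stream it would be plugged in as `.transform(lambda _, rdd: predict_with_labels(model, rdd))`, after which the resulting (label, prediction) pairs can be counted or printed as usual.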
zero323