49

I'm slightly confused about how to save a trained classifier. Re-training a classifier every time I want to use it is obviously slow and wasteful, so how do I save it and then load it again when I need it? The code is below; thanks in advance for your help. I'm using Python with the NLTK Naive Bayes classifier.

classifier = nltk.NaiveBayesClassifier.train(training_set)
# Abridged excerpt of the classifier's train method from the NLTK library source:

def train(labeled_featuresets, estimator=nltk.probability.ELEProbDist):
    # Create the P(label) distribution
    label_probdist = estimator(label_freqdist)
    # Create the P(fval|label, fname) distribution
    feature_probdist = {}
    return NaiveBayesClassifier(label_probdist, feature_probdist)
merv
  • 2
    Are you asking for some kind of persistence strategy? As in save to DB, file and load again? You could just pickle the data and load it again later. – EdChum Apr 04 '12 at 18:33

3 Answers

94

To save:

import pickle

# A context manager closes the file even if pickling raises an error
with open('my_classifier.pickle', 'wb') as f:
    pickle.dump(classifier, f)

To load later:

import pickle

# Open in binary mode, matching the 'wb' mode used when saving
with open('my_classifier.pickle', 'rb') as f:
    classifier = pickle.load(f)
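The save and load steps above can be wrapped in two small helpers. This is only a sketch: `save_classifier`, `load_classifier`, and the stand-in dictionary are hypothetical names introduced for illustration, and the dictionary merely takes the place of a trained classifier so the round trip can run without NLTK installed. The same helpers should work for any picklable object.

```python
import os
import pickle
import tempfile


def save_classifier(classifier, path):
    """Serialize any picklable object (e.g. a trained classifier) to `path`."""
    with open(path, "wb") as f:
        pickle.dump(classifier, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_classifier(path):
    """Load an object previously written by save_classifier.

    Only unpickle files you trust: loading a pickle can run arbitrary code.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Assumption: this dict is a stand-in for a trained model, not a real classifier.
stand_in_model = {"labels": ["pos", "neg"], "priors": {"pos": 0.5, "neg": 0.5}}

# Write to a temporary directory so the example leaves the working directory alone.
model_path = os.path.join(tempfile.mkdtemp(), "my_classifier.pickle")
save_classifier(stand_in_model, model_path)
restored_model = load_classifier(model_path)
```

With a real classifier, `save_classifier(classifier, 'my_classifier.pickle')` would be called once after training, and `load_classifier('my_classifier.pickle')` at the start of later runs.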
Jacob
  • How would I go about it if I want to retrain my model using an already pickled model? `import pickle; f = open('my_classifier.pickle', 'rb'); classifier = pickle.load(f)` ..... then? – Mohsin May 28 '17 at 11:51
  • I get `TypeError: can't pickle module objects` – jbuddy_13 May 15 '20 at 02:19
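A `TypeError` like the one in the comment above is what `pickle` raises when the object being saved, or something it references, holds a module. The commenter's code is not shown, so the following is only a sketch of that general cause: `model_state` and `without_modules` are hypothetical names, and the dictionary is a stand-in rather than a real classifier.

```python
import math
import pickle
import types

# Assumption: a stand-in for state that accidentally includes a module reference.
model_state = {"weights": {"good": 1.5, "bad": -2.0}, "backend": math}

# Pickling fails because module objects cannot be pickled.
try:
    pickle.dumps(model_state)
    pickling_failed = False
except TypeError:
    pickling_failed = True


def without_modules(state):
    """Return a copy of `state` with top-level module-valued entries dropped."""
    return {
        key: value
        for key, value in state.items()
        if not isinstance(value, types.ModuleType)
    }


# Once the module reference is removed, the round trip succeeds.
clean_state = without_modules(model_state)
restored_state = pickle.loads(pickle.dumps(clean_state))
```

In practice the module can simply be re-imported after loading, so it does not need to be stored with the model.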
5

I ran into the same problem: I could not save the object, since it is an instance of NLTK's ELEProbDist class. In any case, NLTK is very slow. Training took 45 minutes on a decent-sized set, so I decided to implement my own version of the algorithm (run it with PyPy, or rename it to .pyx and install Cython). It takes about 3 minutes on the same set, and it can simply save its data as JSON (I'll implement pickle support, which is faster/better).

I started a simple GitHub project; check out the code here.

luke14free
1

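To retrain the pickled classifier:

import pickle
import nltk

with open('originalnaivebayes5k.pickle', 'rb') as f:
    classifier = pickle.load(f)
# train() is a classmethod that builds a new classifier from training_set alone;
# keep its return value, otherwise the loaded classifier is left unchanged
classifier = classifier.train(training_set)
print('Accuracy:', nltk.classify.accuracy(classifier, testing_set) * 100)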
Michael Yadidya