
I am new to TensorFlow and am playing around with some code from GitHub. The code defines a class for the neural network that includes methods to construct the network, formulate the loss function, train the network, perform prediction, and so on.

The skeleton code would look something like this:

class NeuralNetwork:
    def __init__(...):

    def initializeNN():

    def trainNN():

    def predictNN():

etc. The neural network is built with TensorFlow, so the class definition and its methods use TensorFlow calls.
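
To give an idea of what is inside the class, a stripped-down version might look like this (this is only my sketch, not the actual GitHub code; the placeholder shapes, the single dense layer and the method signatures are made up):

import tensorflow as tf

class NeuralNetwork:
    def __init__(self, layers):
        # The object holds a graph, a session and variables; these TensorFlow
        # handles are what pickle/dill apparently cannot serialize.
        self.layers = layers
        self.x = tf.placeholder(tf.float32, shape=[None, layers[0]])
        self.y = tf.placeholder(tf.float32, shape=[None, layers[-1]])
        self.pred = self.initializeNN(self.x)
        self.loss = tf.reduce_mean(tf.square(self.y - self.pred))
        self.train_op = tf.train.AdamOptimizer(1e-3).minimize(self.loss)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def initializeNN(self, x):
        # stand-in for the real architecture
        return tf.layers.dense(x, self.layers[-1])

    def trainNN(self, x_train, y_train, epochs=1000):
        for _ in range(epochs):
            self.sess.run(self.train_op, {self.x: x_train, self.y: y_train})

    def predictNN(self, x_star):
        return self.sess.run(self.pred, {self.x: x_star})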

Now in the main part of my script, I create an instance of this class via

model = NeuralNetwork(...)

and use the methods of model, such as model.predictNN, to produce plots.
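
Roughly like this (illustrative only; x_star and the training arrays are stand-ins for my actual data):

import matplotlib.pyplot as plt

model = NeuralNetwork(...)        # as above
model.trainNN(x_train, y_train)   # the slow step I'd like to avoid repeating
u_pred = model.predictNN(x_star)  # predictions at new points
plt.plot(x_star, u_pred)
plt.show()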

Since training the neural network takes a long time, I'd like to save the object "model" for future use, with the possibility of calling its methods later. I have tried pickle and dill, but both failed. For pickle, I got the error:

TypeError: can't pickle _thread.RLock objects

while for dill, I got:

TypeError: can't pickle SwigPyObject objects
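
The save attempt itself was nothing fancy, roughly this (the filename is just an example):

import pickle

with open("model.pkl", "wb") as f:
    pickle.dump(model, f)   # fails with the TypeError above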

Any suggestions on how I can save the object and still be able to invoke its methods? This is essential, as I may want to perform prediction on a different set of points in the future.

Thanks!

user1237300
  • Have you tried using `tf.train.saver`? [Guide](https://www.tensorflow.org/guide/saved_model) – Eric Zhou Jun 06 '19 at 02:39
  • How does this fit into my code above? Do you mind providing me with an example? I don't see how this enables me to access the methods of my object such as the method to perform prediction. I'd like to save the object itself which is "model" above – user1237300 Jun 06 '19 at 02:52
  • Possible duplicate of [Tensorflow: how to save/restore a model?](https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model) – Billal Begueradj Jun 07 '19 at 15:21

1 Answer


What you should do is the following:

import tensorflow as tf

# Build the graph
model = NeuralNetwork(...)
# Create a train saver/loader object
saver = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Train the model in the same way you are doing it currently
    model.train_model()
    # Once you are done training, just save the model definition and its learned weights
    saver.save(sess, save_path)

And you are done. Then, when you want to use the model again, you can do:

import tensorflow as tf

# Build the graph exactly as before
model = NeuralNetwork(...)
# Create a train saver/loader object
loader = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Load the model variables
    loader.restore(sess, save_path)
    # Train the model again for example
    model.train_model()
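
Since your actual goal is prediction on new points, you could just as well call your prediction method after restoring, instead of training again. Here I'm assuming the method is called predictNN and takes the new inputs x_new; adapt the names to your class:

# Build the graph exactly as before
model = NeuralNetwork(...)
# Create a train saver/loader object
loader = tf.train.Saver()
with tf.Session() as sess:
    # Load the learned weights into the freshly built graph
    loader.restore(sess, save_path)
    # No retraining needed: predict directly on the new points
    predictions = model.predictNN(x_new)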
gorjan
  • Thanks for the help, gorjan. But two clarifications: do I need to use the "ckpt" extension? Also, does your example mean that I have to train the neural network again? I tried it and Python complains that "restoring from checkpoint failed". – user1237300 Jun 08 '19 at 20:13
  • It is not necessary to use the ckpt extension. Also, you don't need to train your model again; that's the point of saving it and restoring it later. My point with the `model.train_model()` example was that you can do whatever you want with the model after you restore it. – gorjan Jun 08 '19 at 23:33