I am using optimizer.get_config() to get the final state of my Adam optimizer (as in https://stackoverflow.com/a/60077159/607528), but .get_config() is returning the initial state. I assume this means one of the following:

  1. .get_config() is supposed to return the initial state
  2. my optimizer is not updating because I've set something up wrong
  3. my optimizer is not updating because tf's adam is broken (highly unlikely)
  4. my optimizer is updating but is being reset somewhere before I call .get_config()
  5. something else?

Of course I originally noticed the issue in a proper project with training and validation sets etc, but here is a really simple snippet that seems to reproduce the issue:

import tensorflow as tf
import numpy as np

x=np.random.rand(100)
y=(x*3).round()

model = tf.keras.models.Sequential([
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x, y, epochs=500)
model.evaluate(x, y)

model.optimizer.get_config()
brook

1 Answer

If you want to resume training, you should save the optimizer's weights, not its config:

    import pickle

    # Optimizer state (iteration count and Adam's moment estimates), not hyperparameters
    weight_values = optimizer.get_weights()
    with open(self.output_path+'optimizer.pkl', 'wb') as f:
        pickle.dump(weight_values, f)

And then load them:

    model.fit(dummy_x, dummy_y, epochs=500)  # build the optimizer by fitting the model on dummy input, e.g. random tensors with the simplest shape
    with open(self.path_to_saved_model+'optimizer.pkl', 'rb') as f:
        weight_values = pickle.load(f)
    optimizer.set_weights(weight_values)
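
To illustrate why the config looks unchanged while the weights do not, here is a minimal pure-Python sketch. `ToyAdam` and everything in it are hypothetical names made up for this example; it is not TensorFlow's implementation and only mimics the general behaviour described above: the config holds hyperparameters, the state (step count and moment estimates) is created lazily on the first update, and loading state into an optimizer that has not been built yet fails.

```python
import math
import pickle


class ToyAdam:
    """Toy Adam over a list of floats. Illustration only, not TensorFlow's class."""

    def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        # State is created lazily on the first update, so it starts empty.
        self.iterations = 0
        self.m = []  # first-moment estimates
        self.v = []  # second-moment estimates

    def get_config(self):
        # Hyperparameters only: training never changes these.
        return {
            "learning_rate": self.learning_rate,
            "beta_1": self.beta_1,
            "beta_2": self.beta_2,
            "epsilon": self.epsilon,
        }

    def get_weights(self):
        # Empty until the first update has built the moment slots.
        if not self.m:
            return []
        return [self.iterations, list(self.m), list(self.v)]

    def set_weights(self, weights):
        expected = len(self.get_weights())
        if len(weights) != expected:
            raise ValueError(f"got {len(weights)} weights, expected {expected}")
        self.iterations = weights[0]
        self.m = list(weights[1])
        self.v = list(weights[2])

    def apply_gradients(self, params, grads):
        if not self.m:  # build the slots on first use
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.iterations += 1
        t = self.iterations
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta_1 * self.m[i] + (1 - self.beta_1) * g
            self.v[i] = self.beta_2 * self.v[i] + (1 - self.beta_2) * g * g
            m_hat = self.m[i] / (1 - self.beta_1 ** t)  # bias correction
            v_hat = self.v[i] / (1 - self.beta_2 ** t)
            step = self.learning_rate * m_hat / (math.sqrt(v_hat) + self.epsilon)
            new_params.append(p - step)
        return new_params


opt = ToyAdam()
config_before = opt.get_config()

params = [1.0, -2.0]
for _ in range(10):
    grads = [2.0 * p for p in params]  # gradient of sum(p**2)
    params = opt.apply_gradients(params, grads)

config_after = opt.get_config()          # same hyperparameters as before training
saved = pickle.dumps(opt.get_weights())  # the part that actually changed

fresh = ToyAdam()
try:
    fresh.set_weights(pickle.loads(saved))  # fails: slots are not built yet
    unbuilt_load_failed = False
except ValueError:
    unbuilt_load_failed = True

fresh.apply_gradients([0.0, 0.0], [0.0, 0.0])  # dummy step builds the slots
fresh.set_weights(pickle.loads(saved))         # now the saved state loads
```

In this toy version the dummy step plays the same role as fitting on dummy input in the answer: it exists only to create the slots that `set_weights` then overwrites.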
Andrey
  • Thanks @andrey: this answer was so clearly correct that I accepted it before testing. Sadly, it has led me to a new problem: `ValueError: You called set_weights(weights) on optimizer Adam with a weight list of length 255, but the optimizer was expecting 0 weights.` Following https://stackoverflow.com/a/49504376/607528 I tried calling `_make_train_function`, which doesn't exist in TF2.3; however, `make_train_function` does. But `model.make_train_function(); model.optimizer.set_weights(weight_values)` still doesn't fix the problem. Thoughts? – brook Feb 11 '21 at 17:35
  • @brook you have to build the optimizer before loading the weights - see the updated answer – Andrey Feb 11 '21 at 18:06