
Keras has a callback, ReduceLROnPlateau, that reduces the learning rate when a specified metric plateaus.
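
For reference, here is a minimal sketch of how that callback is typically used; the monitored metric and parameter values are just illustrative, and model, x_train, x_val, etc. are placeholders:

from keras.callbacks import ReduceLROnPlateau

# Halve the learning rate if the validation loss hasn't improved for 5 epochs.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                              patience=5, min_lr=1e-6)
model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=50, callbacks=[reduce_lr])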

How do you create such a feature in native TensorFlow? Is it possible to use Keras callbacks with a TensorFlow model, or does this need to be written in native TensorFlow? If so, how would you change the learning rate in the middle of a training session?


2 Answers


I'm afraid TensorFlow doesn't support this out of the box (and Keras callbacks aren't directly applicable either). Here's the list of supported learning rate scheduling techniques: they are all different algorithms, but all of them are self-contained, i.e. independent of the training performance.
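
For instance, a self-contained schedule such as tf.train.exponential_decay depends only on the global step, not on any monitored metric. A minimal sketch (the decay numbers are illustrative, and mse stands in for your loss tensor):

global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(0.1, global_step,
                                           decay_steps=10000, decay_rate=0.96,
                                           staircase=True)
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(mse, global_step=global_step)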

But the good news is that all optimizers accept a tensor for the learning rate. So you can create a variable or a placeholder for the learning rate and change its value based on validation performance (which you'll also need to calculate yourself). Here's an example from this wonderful answer:

learning_rate = tf.placeholder(tf.float32, shape=[])
# ...
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(mse)

sess = tf.Session()

# Feed different values for learning rate to each training step.
sess.run(train_step, feed_dict={learning_rate: 0.1})
sess.run(train_step, feed_dict={learning_rate: 0.1})
sess.run(train_step, feed_dict={learning_rate: 0.01})
sess.run(train_step, feed_dict={learning_rate: 0.01})
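
Building on that, here is a minimal sketch of what the plateau logic driving those feeds might look like; the epoch loop, num_epochs, batches_per_epoch, and the compute_validation_loss helper are assumptions for illustration only:

lr = 0.1
best_loss = float('inf')
wait, patience, factor, min_lr = 0, 5, 0.1, 1e-6

for epoch in range(num_epochs):
    for _ in range(batches_per_epoch):
        # Feed the current learning rate (plus your usual input feeds) on every step.
        sess.run(train_step, feed_dict={learning_rate: lr})

    val_loss = compute_validation_loss(sess)  # hypothetical helper
    if val_loss < best_loss - 1e-4:
        best_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            # Plateau detected: shrink the learning rate, but not below min_lr.
            lr = max(lr * factor, min_lr)
            wait = 0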
Maxim

Here's a rough (not 1:1) conversion of Keras's ReduceLROnPlateau that I wrote up. It examines each batch's loss instead of only checking the metric at the end of each epoch, though cooldown and patience are still measured in epochs. It can be used just like tf.train.exponential_decay(...); a usage sketch follows the code below.

I think there's probably a better way to go about this than simply monitoring the minimum loss value, since the minimum could be an extreme outlier. A metric based on some running average of the loss gradient might be better.

import math

import numpy as np
import tensorflow as tf


def plateau_decay(learning_rate, global_step, loss, data_count, batch_size,
                  factor=0.1, patience=10, min_delta=1e-4, cooldown=0, min_lr=0):
    # Convert the epoch-based patience/cooldown settings into step counts.
    steps_per_epoch = math.ceil(data_count / batch_size)
    patient_steps = patience * steps_per_epoch
    cooldown_steps = cooldown * steps_per_epoch

    if not isinstance(learning_rate, tf.Tensor):
        learning_rate = tf.get_variable('learning_rate', initializer=tf.constant(learning_rate),
                                        trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])

    with tf.variable_scope('plateau_decay'):
        # `step` records when the best loss was last improved (or when a cooldown ends),
        # `best` tracks the lowest loss seen so far.
        step = tf.get_variable('step', trainable=False, initializer=global_step,
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])
        best = tf.get_variable('best', trainable=False, initializer=tf.constant(np.Inf, tf.float32),
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])

        def _update_best():
            with tf.control_dependencies([
                tf.assign(best, loss),
                tf.assign(step, global_step),
                tf.print('Plateau Decay: Updated Best - Step:', global_step,
                         'Next Decay Step:', global_step + patient_steps, 'Loss:', loss)
            ]):
                return tf.identity(learning_rate)

        def _decay():
            with tf.control_dependencies([
                tf.assign(best, loss),
                tf.assign(learning_rate, tf.maximum(tf.multiply(learning_rate, factor), min_lr)),
                tf.assign(step, global_step + cooldown_steps),
                tf.print('Plateau Decay: Decayed LR - Step:', global_step,
                         'Next Decay Step:', global_step + cooldown_steps + patient_steps,
                         'Learning Rate:', learning_rate)
            ]):
                return tf.identity(learning_rate)

        def _no_op():
            return tf.identity(learning_rate)

        met_threshold = tf.less(loss, best - min_delta)
        should_decay = tf.greater_equal(global_step - step, patient_steps)

        return tf.cond(met_threshold, _update_best,
                       lambda: tf.cond(should_decay, _decay, _no_op))
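
As a rough usage sketch (the loss tensor, dataset size, and batch size below are placeholders), the returned tensor can be passed straight to an optimizer; since the helper keeps its state in local variables, remember to run tf.local_variables_initializer() as well:

global_step = tf.train.get_or_create_global_step()
learning_rate = plateau_decay(0.1, global_step, loss,
                              data_count=50000, batch_size=128,
                              factor=0.1, patience=10, cooldown=2)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

sess = tf.Session()
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])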
sir_cakes