Here's a conversion of Keras's ReduceLROnPlateau that I wrote up. It isn't 1:1: it examines each batch's loss instead of sampling the metric once at the end of each epoch, though cooldown and patience are still expressed in epochs. It can be used just like tf.train.exponential_decay(...); there's a usage sketch after the code.
I think there's probably a better approach than simply monitoring the minimum loss value, since the minimum could be an extreme outlier. A metric based on some running average of the loss gradient might be better; a simpler smoothing sketch follows the code below.
import math

import numpy as np
import tensorflow as tf

def plateau_decay(learning_rate, global_step, loss, data_count, batch_size,
                  factor=0.1, patience=10, min_delta=1e-4, cooldown=0, min_lr=0):
    steps_per_epoch = math.ceil(data_count / batch_size)
    patient_steps = patience * steps_per_epoch
    cooldown_steps = cooldown * steps_per_epoch

    if not isinstance(learning_rate, tf.Tensor):
        learning_rate = tf.get_variable('learning_rate', initializer=tf.constant(learning_rate),
                                        trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])

    with tf.variable_scope('plateau_decay'):
        # Step at which the loss last improved; pushed forward on decay to cover the cooldown.
        step = tf.get_variable('step', trainable=False, initializer=global_step,
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])
        # Lowest loss seen so far.
        best = tf.get_variable('best', trainable=False, initializer=tf.constant(np.inf, tf.float32),
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])

        def _update_best():
            with tf.control_dependencies([
                tf.assign(best, loss),
                tf.assign(step, global_step),
                tf.print('Plateau Decay: Updated Best - Step:', global_step,
                         'Next Decay Step:', global_step + patient_steps, 'Loss:', loss)
            ]):
                return tf.identity(learning_rate)

        def _decay():
            with tf.control_dependencies([
                tf.assign(best, loss),
                tf.assign(learning_rate, tf.maximum(tf.multiply(learning_rate, factor), min_lr)),
                tf.assign(step, global_step + cooldown_steps),
                tf.print('Plateau Decay: Decayed LR - Step:', global_step,
                         'Next Decay Step:', global_step + cooldown_steps + patient_steps,
                         'Learning Rate:', learning_rate)
            ]):
                return tf.identity(learning_rate)

        def _no_op():
            return tf.identity(learning_rate)

        met_threshold = tf.less(loss, best - min_delta)
        should_decay = tf.greater_equal(global_step - step, patient_steps)

        return tf.cond(met_threshold, _update_best,
                       lambda: tf.cond(should_decay, _decay, _no_op))
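For reference, here's a minimal sketch of wiring it up in TF 1.x graph mode, analogous to how tf.train.exponential_decay is used; loss_op and the data_count/batch_size values are hypothetical placeholders, not anything from above:

# loss_op: your scalar training-loss tensor (hypothetical name).
global_step = tf.train.get_or_create_global_step()
learning_rate = plateau_decay(1e-3, global_step, loss_op,
                              data_count=50000, batch_size=128,
                              factor=0.1, patience=10, cooldown=2)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op, global_step=global_step)

with tf.Session() as sess:
    # The plateau_decay variables live in LOCAL_VARIABLES, so initialize those too.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # Evaluating train_op each step also evaluates the tf.cond, which
    # updates best/step and decays the learning rate when needed.

One consequence of keeping the helper variables in tf.GraphKeys.LOCAL_VARIABLES is that tf.global_variables_initializer() alone won't initialize them, and a default tf.train.Saver won't checkpoint the decayed learning rate.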
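On the outlier point above: short of a gradient-based metric, a cheap alternative (my sketch, not part of the converted code) is to hand plateau_decay a bias-corrected running average of the batch loss, so a single unusually low batch can't set best to an unrepresentative minimum:

def smoothed_loss(loss, decay=0.98):
    # Bias-corrected exponential moving average of the batch loss, in the
    # style of Adam's first-moment estimate. The average is updated once
    # per evaluation of the returned tensor.
    with tf.variable_scope('loss_ema'):
        ema = tf.get_variable('ema', trainable=False, initializer=tf.constant(0., tf.float32),
                              collections=[tf.GraphKeys.LOCAL_VARIABLES])
        count = tf.get_variable('count', trainable=False, initializer=tf.constant(0., tf.float32),
                                collections=[tf.GraphKeys.LOCAL_VARIABLES])
        update_ema = tf.assign(ema, decay * ema + (1. - decay) * loss)
        update_count = tf.assign_add(count, 1.)
        # Divide out the bias introduced by the zero initialization.
        return update_ema / (1. - tf.pow(decay, update_count))

Pass smoothed_loss(loss) in place of loss when calling plateau_decay.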