My problem seems to be very common.
I am doing some reinforcement learning using a vanilla policy gradient method. The environment is a simple one-period game where the state and action spaces are the real line. The agent is a neural network with two output heads that I build manually from Keras dense layers; for example, my first hidden layer is
layers.Dense(NH[0], activation="relu",
             kernel_initializer=initializers.GlorotNormal())(inputs)
where NH is a list containing the number of neurons in each hidden layer. The outputs are the mean and standard deviation of my Gaussian policy. I don't know if this part matters, but I included it nonetheless.
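Because the two heads are the mean and standard deviation of a Gaussian, the log-probability of an action can be written directly in log space rather than as the log of the pdf plus a small epsilon (which is what `custom_loss` in the minimal example further down does). A minimal NumPy sketch, assuming a scalar action; `gaussian_log_prob` is a hypothetical helper name, and the same expression can be written with `tf.square` and `tf.math.log` inside a loss.

```python
import numpy as np

def gaussian_log_prob(action, mu, sigma):
    """Log-density of N(mu, sigma**2) at `action`, computed in log space.

    Mathematically equal to log(pdf), but it never evaluates exp(), so it
    stays finite for actions far from the mean where the pdf underflows to 0.
    """
    z = (action - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)

# For moderate values it agrees with taking the log of the explicit pdf
a, m, s = 1.3, 0.5, 0.8
direct = np.log(np.exp(-0.5 * ((a - m) / s) ** 2) / (s * np.sqrt(2 * np.pi)))
print(np.isclose(gaussian_log_prob(a, m, s), direct))  # True
```

Far from the mean (say `action=100`, `mu=0`, `sigma=0.1`) the direct formula gives `log(0) = -inf`, while the log-space version still returns a finite value.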
The environment is simple: the state is a normally distributed random variable, the action is a real scalar, and there is just one period. I run the policy a number of times, collect the resulting batch, and use tf.GradientTape() to update the network based on a custom loss function. I have no problem running that code thousands of times and watching the algorithm learn.
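The update described above (sample actions from the policy, collect a batch, minimise `-reward * log_prob`) is the REINFORCE estimator. A minimal NumPy sketch of the same idea, with the gradient written out by hand instead of computed by `tf.GradientTape`. Everything here is an assumption for illustration: a linear policy mean `w * s`, a fixed `sigma`, states drawn uniformly from [1, 2], and a quadratic reward like the one in the minimal example further down, for which the best weight is 5.

```python
import numpy as np

def reinforce_linear_gaussian(steps=2000, batch=64, sigma=0.5, lr=0.01, seed=0):
    """Vanilla policy gradient for a toy policy a ~ N(w * s, sigma**2).

    Hypothetical setup: s ~ Uniform(1, 2) and reward -(5 s - a)**2, so the
    optimal weight is w = 5. The score function of the Gaussian mean is
    d/dw log pi(a|s) = (a - w s) * s / sigma**2.
    """
    rng = np.random.default_rng(seed)
    w = 0.0
    for _ in range(steps):
        s = rng.uniform(1.0, 2.0, size=batch)
        a = rng.normal(w * s, sigma)          # one sampled action per state
        r = -(5.0 * s - a) ** 2
        advantage = r - r.mean()              # batch-mean baseline lowers variance
        score = (a - w * s) * s / sigma**2
        grad = np.mean(advantage * score)     # Monte Carlo estimate of dE[r]/dw
        w += lr * grad                        # gradient ascent on expected reward
    return w

print(reinforce_linear_gaussian())  # a value close to 5
```

Minimising `-mean(advantage * log_prob)` with an autodiff tool produces exactly this gradient (with the opposite sign), which is why the Keras version only needs the log-probability and the rewards.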
The real problem is that I'd like to run the whole learning process multiple times, re-initializing the network weights randomly each time, to get distributions for the history of rewards. If I run all of this in a loop, the computer quickly freezes. Apparently this is a very common problem with Keras and TensorFlow, one that people have been complaining about for years. I have tried the usual solutions. In one thread, people suggested adding something like the following at the end of the loop, so that I start from a clean slate before re-initializing the network:
keras.backend.clear_session()
gc.collect()
del actor
That doesn't solve the problem. Then I saw someone propose a function that goes a little further:
import gc
import tensorflow as tf
K = tf.compat.v1.keras.backend  # assumption: the TF1-style Keras backend that provides get_session()

def reset_keras(model):
    # Clear model, if possible
    try:
        del model
    except:
        pass
    # Garbage collection
    gc.collect()
    # Clear and close tensorflow session
    session = K.get_session()  # Get session
    K.clear_session()          # Clear session
    session.close()            # Close session
    # Reset all tensorflow graphs
    tf.compat.v1.reset_default_graph()
That doesn't work either. I also tried changing the order of the first three commands, with the same result.
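One Python detail that may matter for `reset_keras`: a `del` inside a function only unbinds the function's local name, so it cannot free an object the caller still references. A small standard-library sketch showing this with a weak reference; `Model` and `reset` are hypothetical stand-ins, not Keras objects.

```python
import gc
import weakref

class Model:
    """Hypothetical stand-in for a Keras model."""

def reset(model):
    # Mirrors the `del model` step in reset_keras: only the local name goes away.
    del model
    gc.collect()

actor = Model()
alive = weakref.ref(actor)     # a weak reference does not keep the object alive

reset(actor)
print(alive() is not None)     # True: the caller's `actor` still holds the object

del actor                      # drop the caller's own reference instead
gc.collect()
print(alive() is None)         # True: the object has now been freed
```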
Does anyone have an idea how to solve this? It would also help to know why it happens. I'd also like to know how to profile memory usage here, so that I don't have to wait 4 hours to find out that the computer freezes again with the new solution.
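On the profiling side, `psutil.virtual_memory().used` (used in `display_memory` below) measures the whole machine, not the script. A minimal sketch, Unix-only and standard library only, that logs this process's own peak resident set size plus the Python-level heap once per repetition, so an upward trend shows after a few runs instead of hours. The helper names and the dummy workload are assumptions; `psutil.Process().memory_info().rss` is an alternative that reports current rather than peak RSS. Note that `tracemalloc` cannot see memory allocated inside TensorFlow's C++ runtime, so a growing RSS with a flat Python heap points at native allocations.

```python
import resource
import sys
import tracemalloc

def peak_rss_mib():
    """Peak resident set size of this process in MiB (Unix only).

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    scale = 1 if sys.platform == "darwin" else 1024
    return peak * scale / 2**20

def python_heap_mib():
    """Current size of Python allocations traced by tracemalloc, in MiB."""
    current, _peak = tracemalloc.get_traced_memory()
    return current / 2**20

tracemalloc.start()
log = []
for run in range(3):                              # stand-in for the M repetitions
    junk = [bytes(10**6) for _ in range(5)]       # hypothetical per-run workload
    log.append((run, peak_rss_mib(), python_heap_mib()))
    del junk

for run, rss, heap in log:
    print(f"run {run}: peak RSS {rss:.1f} MiB, Python heap {heap:.2f} MiB")
```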
In fact, if you have a minimal working example that demonstrably doesn't lead to exploding memory use, I would gladly recode the whole damn thing from scratch to get rid of the problem. As a side note, why haven't the developers solved this issue? It's the only package, in either R or Python, where this has ever happened to me...
EDIT: As requested, here is a minimal working example of the issue. I made up a quick game: a moving target, where the optimal action is to play a fixed multiple of the state value, which yields a reward of 0.
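The game's reward can be checked on its own: the best action tracks the state (5 * state) and earns exactly 0, while anything else is penalised quadratically. A minimal sketch that passes arguments by keyword, which guards against mixing up the `(action, state)` order; the uniform state draw is an assumption for illustration.

```python
import numpy as np

def get_reward(action, state):
    """Reward of the moving-target game: zero only when action == 5 * state."""
    return -(5 * state - action) ** 2

rng = np.random.default_rng(0)
states = rng.uniform(1.0, 2.0, size=4)

# The optimal action moves with the state and always earns a reward of 0.
best = [get_reward(action=5 * s, state=s) for s in states]

# Swapping the two arguments silently changes the game being played.
swapped = get_reward(action=states[0], state=5 * states[0])
```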
I wrote an actor class and used a simple linear regression as a critic, which can be turned off with do_actor_critic. If you watch the memory usage, it keeps climbing. This game won't crash my computer unless I play it much longer, but it shows that memory usage increases.
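The critic acts as a regression baseline: predict the reward from the state, then train the actor on reward minus prediction. A NumPy sketch of a polynomial version, assuming the intent of `order = 3` was to regress on the columns 1, s, s^2, s^3; `polynomial_baseline` is a hypothetical name. It uses `np.linalg.lstsq` rather than forming `inv(X.T @ X)` explicitly, which is better conditioned and still works when every state in a batch is identical.

```python
import numpy as np

def polynomial_baseline(state, reward, order=3):
    """Fitted values of a least-squares regression of reward on 1, s, ..., s**order."""
    s = np.asarray(state, dtype=float).reshape(-1, 1)
    y = np.asarray(reward, dtype=float).reshape(-1, 1)
    X = np.hstack([s**k for k in range(order + 1)])   # columns 1, s, s^2, s^3
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return X @ coef

states = np.linspace(1.0, 2.0, 10)
rewards = -(5 * states - 7.0) ** 2       # quadratic in s, so a cubic fit is exact
advantage = rewards.reshape(-1, 1) - polynomial_baseline(states, rewards)
print(np.abs(advantage).max() < 1e-8)    # True
```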
import numpy as np
import psutil
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow.keras.initializers as initializers
import tensorflow.python.keras.backend as kb
import matplotlib.pyplot as plt

BATCH = 10
MC_DRAWS = 2000
M = 10

# Training options
LR = 0.01

def display_memory():
    print( f'{round(psutil.virtual_memory().used/2**30, 2)} GB' )
class Actor:

    def __init__(self):
        self.nn = self.make_actor()
        self.batch = BATCH
        self.opt = keras.optimizers.Adam( learning_rate = LR )

    def make_actor(self):
        # shape must be a tuple: (1,) rather than (1)
        inputs = layers.Input( shape=(1,) )
        hidden = layers.Dense(5, activation='relu',
                              kernel_initializer=initializers.GlorotNormal() )(inputs)
        mu = layers.Dense(1, activation='linear',
                          kernel_initializer=initializers.GlorotNormal() )(hidden)
        sigma = layers.Dense(1, activation='softplus',
                             kernel_initializer=initializers.GlorotNormal() )(hidden)
        nn = keras.Model(inputs=inputs, outputs=[mu, sigma])
        return nn

    def update_weights(self, state, action, reward):
        # Get proper format
        state = tf.constant(state, dtype='float32', shape=(self.batch,1))
        action = tf.constant(action, dtype='float32', shape=(self.batch,1))
        reward = tf.constant(reward, dtype='float32', shape=(self.batch,1))
        # Update Policy Network Parameters
        with tf.GradientTape() as tape:
            # Compute Gaussian loss
            loss_value = self.custom_loss(state, action, reward)
            loss_value = tf.math.reduce_mean( loss_value, keepdims=True )
        # Compute gradients
        grads = tape.gradient(loss_value, self.nn.trainable_variables)
        # Apply gradients to update network weights
        self.opt.apply_gradients(zip(grads, self.nn.trainable_variables))

    def custom_loss(self, state, action, reward):
        # Obtain mean and standard deviation
        nn_mu, nn_sigma = self.nn(state)
        # Gaussian pdf
        pdf_value = tf.exp(-0.5 *((action - nn_mu) / (nn_sigma))**2) *\
                    1/(nn_sigma*tf.sqrt(2 *np.pi))
        # Log probabilities
        log_prob = tf.math.log( pdf_value + 1e-5 )
        # Compute loss
        loss_actor = -reward * log_prob
        return loss_actor
class moving_target_game:

    def __init__(self):
        self.action_range = [-np.inf, np.inf]
        self.state_range = [1, 2]
        self.reward_range = [-np.inf, 0]

    def draw(self):
        return np.random.randint(low = self.state_range[0],
                                 high = self.state_range[1])

    def get_reward(self, action, state):
        return -(5*state - action)**2
class Critic:

    def __init__(self):
        self.order = 3
        self.projection = None

    def predict(self, state, reward):
        # Enforce proper format
        x = np.array( state ).reshape(-1,1)
        y = np.array( reward ).reshape(-1,1)
        # Make regression matrix
        X = np.ones( shape = x.shape )
        for i in range( self.order ):
            X = np.hstack( (X, x**(i+1)) )
        # Prediction
        xt = x.transpose()
        P = x @ np.linalg.inv( xt @ x ) @ xt
        Py = P @ y
        self.projection = P
        return Py
#%% Moving Target Game with Actor and Actor-Critic

do_actor_critic = True

display_memory()

history = np.zeros( shape=(MC_DRAWS, M) )
env = moving_target_game()

for m in range(M):

    # New Actor Network
    actor = Actor()

    if do_actor_critic:
        critic = Critic()

    for i in range(MC_DRAWS):

        state_tape = []
        action_tape = []
        reward_tape = []

        for j in range(BATCH):

            # Draw state
            state = env.draw()
            s = tf.constant([state], dtype='float32')
            # Take action
            mu, sigma = actor.nn( s )
            a = tf.random.normal([1], mean=mu, stddev=sigma)
            # Reward (get_reward takes the action first, then the state)
            r = env.get_reward( a, state )
            # Collect results
            action_tape.append( float(a) )
            reward_tape.append( float(r) )
            state_tape.append( float(state) )
            del (s, a, mu, sigma)

        # Update network weights
        history[i,m] = np.mean( reward_tape )

        if do_actor_critic:
            # Update critic
            value = critic.predict(state_tape, reward_tape)
            # Benchmark reward
            mod = np.array(reward_tape).reshape(-1,1) - value
            # Update actor
            actor.update_weights(state_tape, action_tape, mod)
        else:
            actor.update_weights(state_tape, action_tape, reward_tape)

    del actor
    kb.clear_session()

    if do_actor_critic:
        del critic

    print( f'Average Reward on last: {np.mean(reward_tape)} ' )
    display_memory()

plt.plot( history )
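A workaround often suggested when memory keeps growing across repeated model builds is to run each repetition in its own process, so the operating system reclaims everything when that process exits. A minimal sketch under that assumption, using only the standard library and NumPy. `run_one_repetition` is a hypothetical placeholder for the body of the `for m in range(M)` loop, and the approach assumes each repetition only needs to return something small, such as one column of `history`. The default start method is spawn, because forking a process that has already initialised TensorFlow is commonly reported to hang.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

import numpy as np

def run_one_repetition(seed, n_draws):
    """Hypothetical stand-in for one pass of the `for m in range(M)` loop.

    In the real script, TensorFlow would be imported and the Actor built and
    trained in here, so everything it allocates lives in the child process.
    """
    rng = np.random.default_rng(seed)
    return rng.normal(size=n_draws)   # plays the role of history[:, m]

def run_isolated(fn, *args, start_method="spawn"):
    """Run fn(*args) in a fresh worker process and return its result.

    A new single-worker pool is created per call, so the worker exits, and
    the OS reclaims all of its memory, when the `with` block ends.
    """
    ctx = mp.get_context(start_method)
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        return pool.submit(fn, *args).result()
```

With this sketch, the main loop would become `history[:, m] = run_isolated(run_one_repetition, m, MC_DRAWS)`, called from inside an `if __name__ == "__main__":` guard, which the spawn start method requires.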