I am trying to implement an easily understandable version of the CoinDICE algorithm from the Google DeepMind paper CoinDICE: Off-Policy Confidence Interval Estimation, so that the Reinforcement Learning community can use it for their own applications. Since the theory behind the algorithm is complicated and the official GitHub repository contains code that is hard to understand and does not connect well to the theory in the paper, I think a lot of people would benefit from this.
I have already implemented everything that is needed in theory and reviewed the code with some of my peers. Nevertheless, there still seems to be a problem with producing the correct results. I will only describe my problem here without going into the details of the CoinDICE paper itself.
As my application I chose the CartPole-v1 environment and trained behavior and evaluation policies with the A2C algorithm from stable_baselines3 for 5,000 and 10,000 steps respectively.
import gymnasium as gym
from stable_baselines3 import A2C

# behavior policy "b" and evaluation (target) policy "e"
env = gym.make("CartPole-v1", render_mode="rgb_array")
models = {key: A2C("MlpPolicy", env, verbose=0) for key in ["b", "e"]}
models["b"].learn(total_timesteps=5_000)
models["e"].learn(total_timesteps=10_000)
vec_env = models["b"].get_env()
Their policy values should lie around 70-100. However, my CoinDICE class yields an upper bound on the policy value that is far below what I would expect. Each time a batch gets processed, the following stats are recorded:

- L: the approximation of the upper bound of the policy value,
- grads_Q: the norm of the gradient of L with respect to the Q-network,
- grads_zeta: the same as grads_Q, just with the zeta-network,
- fun: the error between the KL-divergence and its required bound,
- D: the KL-divergence,
- lamda: the Lagrange multiplier that bounds the KL-divergence (lambda is a Python keyword),
- ell_mean & ell_std: the mean and standard deviation of the sample of the random vector ell from the paper,
- w_mean & w_std: the mean and standard deviation of the output of equation (37) from the paper for w.
Using a function plot_stats(stats_upper, decimals=4), I arrive at the following plot (stats_upper gets calculated below).
[plot: stats over each processed batch]
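plot_stats itself is not shown here; a minimal sketch of what it could look like, assuming stats is a dict mapping each stat name to an array of shape (n_epochs, n_batches) as produced by the CoinDICE class below:

import numpy as np
import matplotlib.pyplot as plt

def plot_stats(stats, decimals=4):
    """plot each recorded stat over all processed batches (epochs concatenated)"""
    fig, axes = plt.subplots(len(stats), 1, figsize=(8, 2 * len(stats)), sharex=True)
    for ax, (name, values) in zip(axes, stats.items()):
        values = np.asarray(values).flatten()  # (n_epochs, n_batches) -> one long sequence
        ax.plot(values)
        ax.set_ylabel(name)
        ax.set_title(f"{name}: last value = {round(float(values[-1]), decimals)}")
    axes[-1].set_xlabel("processed batch")
    plt.tight_layout()
    plt.show()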
I divided my code into a SamplerCartpole class and a CoinDICE class, whose instantiations you can find in the following two blocks of code.
from tensorflow.keras.optimizers import Adam  # SamplerCartpole, CoinDICE and get_model are defined further below

gamma = 0.99
T = float("inf")
n_batches = 256
batch_size = 128
sampler = SamplerCartpole(vec_env, models, gamma, T, n_batches, batch_size)

input_shapes = {"s": (4,), "a": (1,)}
Q_net = get_model(input_shapes)
zeta_net = get_model(input_shapes, last_activation="softplus")
lr = 0.0001
Q_opt = Adam(learning_rate=lr)
zeta_opt = Adam(learning_rate=lr)

alpha = 0.05
n_epochs = 8
lamda_init = 1
coindice = CoinDICE(sampler, alpha, n_epochs, Q_net, Q_opt, zeta_net, zeta_opt, lamda_init)
L, stats_upper = coindice.upper(record_stats=True)
Since sampling is specific to the environment and the policy format, I have a general Sampler class, which is inherited by the more specific SamplerCartpole.
import numpy as np
import tensorflow as tf
from tqdm import tqdm


class Sampler:
    """
    connects the environment (MDP) and the model(s) (policies) to the algorithm
    """
def __init__(self, env, models, gamma, T, n_batches, batch_size):
self.env = env
self.gamma = gamma
self.T = T
self.n_batches = n_batches
self.batch_size = batch_size
self.models = models
self.i_batch = 0
self.get_batches()
def preprocess_state(self, s):
raise NotImplementedError("method must be overwritten")
def preprocess_action(self, a):
raise NotImplementedError("method must be overwritten")
def preprocess_reward(self, r):
raise NotImplementedError("method must be overwritten")
def get_trajectory(self):
states = []
actions = []
rewards = []
s = self.env.reset()
s = self.preprocess_state(s)
states.append(s)
done = False
t = 0
while not done and t < self.T:
a = self.get_action(s)
s, r, done, _ = self.env.step(a)
a = self.preprocess_action(a)
s = self.preprocess_state(s)
r = self.preprocess_reward(r)
actions.append(a)
states.append(s)
rewards.append(r)
t += 1
length = t
return states, actions, rewards, length
def get_batches(self):
trajectories = []
lengths = []
for _ in tqdm(range(self.batch_size), desc="getting trajectories"):
states, actions, rewards, length = self.get_trajectory()
trajectories.append([states, actions, rewards])
lengths.append(length)
self.batches = [{"s_0": [], "s": [], "a": [], "r": [], "s_prime": []} for _ in range(self.n_batches)]
for batch in self.batches:
for trajectory, length in zip(trajectories, lengths):
states, actions, rewards = trajectory
i = np.random.randint(1, length-1)
keys = ["s_0", "s", "a", "r", "s_prime"]
values = [states[0], states[i], actions[i], rewards[i], states[i+1]]
for key, value in zip(keys, values):
batch[key].append(value)
for key, value in batch.items():
batch[key] = tf.concat(value, axis=0)
def get_batch(self):
batch = self.batches[self.i_batch]
self.i_batch += 1
self.i_batch %= self.n_batches
return batch["s_0"], batch["s"], batch["a"], batch["r"], batch["s_prime"]
@property
def action_space(self):
raise NotImplementedError("method must be overwritten")
def get_probabilities(self, s):
raise NotImplementedError("method must be overwritten")
def get_action(self, s):
raise NotImplementedError("method must be overwritten")
import numpy as np
import tensorflow as tf
from stable_baselines3.common.utils import obs_as_tensor

from sampling import Sampler

class SamplerCartpole(Sampler):
def preprocess_state(self, s):
s = tf.convert_to_tensor(s)
s = tf.squeeze(s)
s = tf.expand_dims(s, axis=0)
return s
def preprocess_action(self, a):
a = int(a)
a = tf.convert_to_tensor(a)
a = tf.expand_dims(a, axis=0)
a = tf.expand_dims(a, axis=0)
return a
def preprocess_reward(self, r):
r = float(r)
r = tf.convert_to_tensor(r)
r = tf.expand_dims(r, axis=0)
r = tf.cast(r, dtype=tf.float32)
return r
@property
def action_space(self):
        return [tf.convert_to_tensor([i]) for i in range(2)]  # CartPole-v1 has two discrete actions (0 and 1)
def get_probabilities(self, s):
s = s.numpy()
# https://stackoverflow.com/a/70012691/16192280
s = obs_as_tensor(s, self.models["e"].policy.device)
p = self.models["e"].policy.get_distribution(s).distribution.probs.detach().numpy()
p = np.squeeze(p).T
return p
def get_action(self, s):
a, _ = self.models["b"].predict(s, deterministic=True)
return a
Finally, the CoinDICE class contains the algorithm itself, which is based on Algorithm 1 from the source paper. The get_model function builds the two networks the algorithm needs.
# ---------------------------------------------------------------- #
import tensorflow as tf
import numpy as np
from scipy.stats import chi2
from scipy.optimize import minimize
from tqdm import tqdm
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Dropout, Concatenate
import matplotlib.pyplot as plt
# ---------------------------------------------------------------- #
def get_model(input_shapes, last_activation="linear"):
"""
input_shapes ... dict with keys "s" and "a" and values as input shape tuples for states and actions respectively
"""
s_in = Input(shape=input_shapes["s"])
s_out = Dense(16, activation="sigmoid")(s_in)
a_in = Input(shape=input_shapes["a"])
a_out = Dense(1, activation="sigmoid")(a_in)
s_a_in = Concatenate()([s_out, a_out])
x = Dense(16, activation="relu")(s_a_in)
# x = Dropout(0.1)(x)
# x = Dense(32, activation="relu")(x)
x = Dropout(0.1)(x)
x = Dense(16, activation="relu")(x)
x = Dropout(0.1)(x)
s_a_out = Dense(1, activation=last_activation)(x)
model = Model(inputs=[s_in, a_in], outputs=s_a_out)
    model.compile()  # not strictly necessary, since training happens manually via GradientTape
return model
# ---------------------------------------------------------------- #
class CoinDICE:
def __init__(self, sampler, alpha, n_epochs, Q_net, Q_opt, zeta_net, zeta_opt, lamda_init):
"""
sampler : samples data from environment according to behavior policy \n
policy_wrapper : for getting expected network output for state-action values for fixed state over actions \n
gamma : discount factor \n
alpha : 1 - alpha is the conficence \n
n_epochs : # epochs \n
Q_net : neural network for Q (see pseudocode) \n
Q_opt : optimizer for Q_net (see pseudocode) \n
zeta_net : neural network for zeta (see pseudocode), output should be >= 0 \n
zeta_opt : optimizer for zeta_net (see pseudocode) \n
eta : learning rate \n
"""
self.sampler = sampler
self.alpha = alpha
self.n_epochs = n_epochs
# set divergence limit
self.xi = 1/2 * chi2.ppf(1-self.alpha, df=1)
self.Q_net = Q_net
self.Q_opt = Q_opt
self.zeta_net = zeta_net
self.zeta_opt = zeta_opt
assert lamda_init > 0
self.lamda_init = lamda_init
@property
def gamma(self):
return self.sampler.gamma
@property
def n(self):
return self.sampler.batch_size
@property
def D_bound(self):
return self.xi / self.n
@property
def n_batches(self):
return self.sampler.n_batches
@property
def action_space(self):
return self.sampler.action_space
def get_probabilities(self, s):
return self.sampler.get_probabilities(s)
def get_value(self, net, s, a=None):
if a is None: # take expectation over a
actions = self.action_space
probabilities = self.get_probabilities(s)
expected_value = tf.zeros(len(s))
for a, p in zip(actions, probabilities):
a = tf.expand_dims(a, axis=0)
a = tf.tile(a, [len(s), 1])
expected_value += tf.squeeze(net([s, a])) * p
return expected_value
else:
return tf.squeeze(net([s, a]))
    def confidence_bound(self, sign, record_stats=False):
        lamda = self.lamda_init
        # fresh copies, so repeated calls do not share weights or optimizer state;
        # clone_model only accepts Keras models, so the optimizers are copied via their config instead
        Q_net = tf.keras.models.clone_model(self.Q_net)
        Q_opt = tf.keras.optimizers.deserialize(tf.keras.optimizers.serialize(self.Q_opt))
        zeta_net = tf.keras.models.clone_model(self.zeta_net)
        zeta_opt = tf.keras.optimizers.deserialize(tf.keras.optimizers.serialize(self.zeta_opt))
w = tf.ones(self.n) / self.n
stat_names = ["L", "D", "lamda", "fun", "grads_Q", "grads_zeta"] + ["w_mean", "w_std", "ell_mean", "ell_std"]
stats = {stat_name: np.zeros([self.n_epochs, self.n_batches]) for stat_name in stat_names}
for i_epoch in range(self.n_epochs):
pbar = tqdm(range(self.n_batches), desc=f"epoch {i_epoch+1}/{self.n_epochs}")
for i_batch in pbar:
                # Pseudocode: sample from the target policy a_0^(j) ~ pi(s_0^(j)), a_prime^(j) ~ pi(s_prime^(j)) for j = 1, ..., n.
                # Here the expectations over the target policy are computed exactly instead (see get_value with a=None).
s_0, s, a, r, s_prime = self.sampler.get_batch()
# ---------------------------------------------------------------- #
# Compute loss terms:
# ell^(j) := (1 - gamma) Q_{theta_1}(s_0^(j), a_0^(j)) + zeta_{theta_2}(s^(j), a^(j)) * (-Q_{theta_1}(s^(j), a^(j)) + r^(j) + gamma Q_{theta_1}(s_prime^(j), a_prime^(j)))
with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
tape.watch(Q_net.trainable_weights)
tape.watch(zeta_net.trainable_weights)
Q_0 = self.get_value(Q_net, s_0)
Q = self.get_value(Q_net, s, a)
Q_prime = self.get_value(Q_net, s_prime)
zeta = self.get_value(zeta_net, s, a)
A = (1 - self.gamma) * Q_0
B = zeta * (r + self.gamma * Q_prime - Q)
ell = A + B
stats["ell_mean"][i_epoch, i_batch] = np.mean(ell)
stats["ell_std"][i_epoch, i_batch] = np.std(ell)
# Compute loss L := sum_{j=1}^n w^(j) * ell^(j)
L = tf.reduce_sum(ell * w, axis=0)
stats["L"][i_epoch, i_batch] = float(L)
loss_Q = L
loss_zeta = -tf.reduce_sum(B * w, axis=0) # -L
# ---------------------------------------------------------------- #
# Update (theta_1, theta_2) <- OPT_theta(L, theta_1, theta_2)
for loss, net, opt, key in zip([loss_Q, loss_zeta], [Q_net, zeta_net], [Q_opt, zeta_opt], ["grads_Q", "grads_zeta"]):
grads = tape.gradient(loss, net.trainable_weights)
opt.apply_gradients(zip(grads, net.trainable_weights))
stats[key][i_epoch, i_batch] = tf.norm([tf.norm(grad) for grad in grads])
del tape
# ---------------------------------------------------------------- #
# Update (w, lamda) by (37)
def get_w(lamda):
w = tf.math.exp(ell / lamda)
w /= tf.reduce_sum(w, axis=0)
return w
def get_D(w):
# D = tf.reduce_sum(w * tf.math.log(w), axis=0) # paper
D = tf.reduce_sum(w * tf.math.log(w * self.n), axis=0) # mine
D = float(D)
return D
def fun(x):
w = get_w(x)
D = get_D(w)
error = abs(D - self.D_bound)
return error
# x0 = lamda
# solution = minimize(fun=fun, x0=x0, bounds=[(0, None)])#, options={"maxiter": 100}, method="BFGS")
# lamda = float(solution["x"])
# assert abs(lamda) > 1e-9
lamdas = np.linspace(max(lamda-1, 0.1), lamda+1, 100)
funs = [fun(lamda) for lamda in lamdas]
i = np.nanargmin(funs)
lamda = lamdas[i]
w = get_w(lamda)
# print(f"{i=}; {lamda=}")
# if i_batch % 10 == 0:
# plt.figure(figsize=(8, 2))
# plt.plot(lamdas, funs)
# plt.axvline(lamda, linestyle=":", color="black")
# # d = 0.1
# # plt.xlim([lamda-d, lamda+d])
# plt.show()
stats["w_mean"][i_epoch, i_batch] = np.mean(w)
stats["w_std"][i_epoch, i_batch] = np.std(w)
stats["lamda"][i_epoch, i_batch] = lamda
stats["D"][i_epoch, i_batch] = get_D(w)
stats["fun"][i_epoch, i_batch] = fun(lamda)
# ---------------------------------------------------------------- #
pbar.set_postfix({key: stat[i_epoch, i_batch] for key, stat in stats.items()})
bound = L / (1 - self.gamma)
return (bound, stats) if record_stats else bound
def upper(self, record_stats=None):
sign = 1
        return self.confidence_bound(sign, record_stats)
    def lower(self, record_stats=None):
        # not implemented yet; kept as a placeholder for the lower confidence bound
        raise NotImplementedError
        sign = -1
        return self.confidence_bound(sign, record_stats)
# ---------------------------------------------------------------- #