I save the weights of my module with `tf.train.Checkpoint`:

```python
chkp_path = "./policies/deep_cfr"
checkpoint = tf.train.Checkpoint(model=deep_cfr_solver._policy_network)
checkpoint.write(chkp_path)
```
where `deep_cfr_solver._policy_network` is an instance of a `tf.Module` subclass.
Afterwards, I try to load the module again in another function:

```python
policy_network = simple_nets.MLP(input_size=4498, hidden_sizes=[8, 8], output_size=4)
new_checkpoint = tf.train.Checkpoint(model=policy_network)
new_checkpoint.restore("./policies/deep_cfr")
info = state.information_state_tensor(0)
tensor = tf.convert_to_tensor(info)
tensor = tf.reshape(tensor, shape=[1, tensor.shape[0]])
action_logits = policy_network(tensor)
action_probs = tf.nn.softmax(action_logits)
with tf.compat.v1.Session() as sess:
    print(action_probs.eval())
```
where `state` is an OpenSpiel game state (`pyspiel.State`).
However, the last line of this code raises the following error:

```
Exception has occurred: FailedPreconditionError
Attempting to use uninitialized value mlp/weights
	 [[node mlp/weights/read (defined at /open_spiel/open_spiel/python/simple_nets.py:48) ]]
```
The code I use is based on this documentation.
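A stripped-down, self-contained version of the same save/restore round trip may help isolate the problem. It is a sketch under two assumptions: TF1 graph mode (as enabled by `tf.disable_v2_behavior()`, which OpenSpiel's TensorFlow Deep CFR uses), and a hypothetical `TinyNet` standing in for `simple_nets.MLP`. It has not been verified against OpenSpiel itself. The relevant difference from the code above is that, in graph mode, `restore()` only builds restore ops, so here they are run explicitly with `status.run_restore_ops(sess)` in the same session that evaluates the network.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TF1 graph mode, as in OpenSpiel's TensorFlow Deep CFR.
tf.disable_v2_behavior()


class TinyNet(tf.Module):
    """Hypothetical stand-in for simple_nets.MLP: a single 3x2 weight matrix."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.zeros([3, 2]), name="weights")

    def __call__(self, x):
        return tf.matmul(x, self.w)


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "deep_cfr")

# Save: fake a training step, then write the weights from the live session.
with tf.Graph().as_default():
    net = TinyNet()
    train_op = net.w.assign(tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
    ckpt = tf.train.Checkpoint(model=net)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        ckpt.write(ckpt_prefix, session=sess)

# Restore in a fresh graph, e.g. inside another function.
with tf.Graph().as_default():
    net = TinyNet()
    ckpt = tf.train.Checkpoint(model=net)
    status = ckpt.restore(ckpt_prefix)  # graph mode: only builds restore ops
    logits = net(tf.constant([[1.0, 2.0, 3.0]]))
    probs = tf.nn.softmax(logits)
    with tf.Session() as sess:
        # Run the restore ops in the same session that evaluates the network;
        # otherwise the variables in this session stay uninitialized.
        status.run_restore_ops(sess)
        restored_logits, restored_probs = sess.run([logits, probs])

print(restored_logits)  # [[4. 5.]]
```

`status.initialize_or_restore(sess)` is an alternative to `run_restore_ops`: it also runs the initializers of any variables that are not found in the checkpoint.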