I'm doing some Q-learning with the simple_rl library. I've trained a QLearningAgent and am trying to inspect the q-table to see what strategy the agent arrives at. The q-table (which is a defaultdict) is much larger than I would have expected: the game I am training the agent on only has 16 distinct states, yet the q-table contains over 3,000 entries.
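Concretely, this is roughly how I'm inspecting it (q_func is the agent's q-table; its keys are states):

print(len(agent.q_func))          # over 3,000 entries
for k in list(agent.q_func)[:5]:  # peek at a few keys
    print(k, k.data)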
To debug, I checked len(set(agent.q_func.keys())) < len(list(agent.q_func.keys())) and found that this was True, i.e. collapsing the keys into a set removes some of them.
The keys are objects of a custom State class. I checked the hashes of the seemingly duplicated keys and they were equal to each other, as you would expect, so I don't think the problem is that they're hashing incorrectly.
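This is roughly the check I ran (a sketch of my diagnostic: group the current keys by their current hash, then count buckets holding more than one key):

from collections import defaultdict

by_hash = defaultdict(list)
for k in agent.q_func:
    by_hash[hash(k)].append(k)

# Buckets with more than one key are the apparent "duplicates".
print(sum(1 for ks in by_hash.values() if len(ks) > 1))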
For reference, this is how the relevant methods are defined within the class:
def __eq__(self, other):
    if isinstance(other, State):
        return self.data == other.data
    return False

def __hash__(self):
    if type(self.data).__module__ == np.__name__:
        # Numpy arrays
        return hash(str(self.data))
    elif self.data.__hash__ is None:
        return hash(tuple(self.data))
    else:
        return hash(self.data)
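As a sanity check (my own, not taken from the library), two states built from equal lists do compare equal and hash identically:

from simple_rl.mdp.StateClass import State

a = State(data=[0, 1, 1, 0])
b = State(data=[0, 1, 1, 0])
print(a == b)              # True
print(hash(a) == hash(b))  # True -- lists aren't hashable, so both hash tuple(self.data)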
I know that dictionaries don't support duplicate keys, so what is actually going on?
EDIT: added a minimal reproducible example.
import random
from simple_rl.mdp.MDPClass import MDP
from simple_rl.mdp.StateClass import State
from simple_rl.agents import QLearningAgent
from collections import defaultdict

class Toy:
    def __init__(self):
        self.buttons = [random.randint(0, 1) for _ in range(4)]

    def __eq__(self, other):
        return isinstance(other, Toy) and \
            self.buttons == other.buttons

    def __hash__(self):
        return hash(tuple(self.buttons))

    def spin(self):
        steps = random.randint(0, 3)
        self.buttons = self.buttons[steps:] + self.buttons[:steps]

    def press(self, to_press):
        for button in to_press:
            self.buttons[button] = 1 - self.buttons[button]
        self.spin()

    def is_solved(self):
        return all(button == 0 for button in self.buttons)
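# Quick demo of the toy's mechanics (my own illustration, safe to delete;
# not needed to reproduce the issue): press() toggles the given buttons in
# place and then spin() rotates the list; the toy is solved when every
# button is off.
demo = Toy()
demo.buttons = [0, 1, 0, 0]
demo.press([1])          # toggles button 1 off, then spins (all zeros stay zeros)
print(demo.is_solved())  # True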
class ToyMDP(MDP):
    ACTIONS = ["1", "2", "3", "4"]

    def __init__(self):
        # Setup init state.
        self.toy = Toy()
        init_state = State(data=self.toy.buttons)
        MDP.__init__(self, ToyMDP.ACTIONS, self._transition_func, self._reward_func, init_state=init_state)

    def get_parameters(self):
        param_dict = defaultdict(int)
        param_dict["num_buttons"] = 4
        return param_dict

    def _reward_func(self, state, action, next_state):
        if state.is_terminal():
            return 0
        return int(next_state.is_terminal())

    def _transition_func(self, state, action):
        # Get buttons to press from action name.
        to_press = list(range(int(action)))
        # Make new state.
        self.toy.press(to_press)
        new_state = State(self.toy.buttons)
        # Set terminal.
        if self._is_goal_state(state):
            new_state.set_terminal(True)
        return new_state

    def _is_goal_state(self, state):
        return sum(state) == 0

    def __str__(self):
        return 'toy'

def train_agent(agent, mdp, num_steps=1000, num_episodes=10):
    for i in range(num_episodes):
        state = mdp.get_init_state()
        reward = 0
        for _ in range(num_steps):
            # Compute the agent's policy.
            action = agent.act(state, reward)
            # Terminal check.
            if state.is_terminal():
                break
            # Execute in MDP.
            reward, next_state = mdp.execute_agent_action(action)
            # Update pointer.
            state = next_state
        agent.end_of_episode()
        mdp.reset()

mdp = ToyMDP()
agent = QLearningAgent(actions=mdp.get_actions())
train_agent(agent, mdp)
Then run the following to find the duplicates:
q = set()
for k in agent.q_func:
    q.add(k)  # or q.add(hash(k))
print(len(q), len(agent.q_func))
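To confirm that genuinely equal keys coexist in the table, I also tried something like this (a rough sketch that just compares every pair of keys):

keys = list(agent.q_func.keys())
pairs = [(a, b) for i, a in enumerate(keys) for b in keys[i + 1:] if a == b]
print(len(pairs))  # should be > 0 given the duplicates above
a, b = pairs[0]    # two distinct dict keys that compare equal
print(a.data, b.data, hash(a) == hash(b))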