
I'm doing some Q-learning with the simple_rl library. I've trained a QLearningAgent and am trying to inspect the q-table to see what strategy the agent arrives at.

The q-table (which is a defaultdict) is much larger than I would have expected. The game I am training the agent on only has 16 different states. However, the q-table contained over 3,000 entries.

To debug I checked len(set(agent.q_func.keys())) < len(list(agent.q_func.keys())) and found that this was True.

The keys are objects of a custom State class, but I checked the hashes of the seemingly-duplicated ones and they were the same as each other as you would expect, so I don't think it is that they're not hashing properly.

For reference, this is how the methods are defined in the State class:

def __eq__(self, other):
    if isinstance(other, State):
        return self.data == other.data
    return False

def __hash__(self):
    if type(self.data).__module__ == np.__name__:
        # Numpy arrays
        return hash(str(self.data))
    elif self.data.__hash__ is None:
        return hash(tuple(self.data))
    else:
        return hash(self.data)

I know that dictionaries don't support duplicate keys, so what is actually going on?

EDIT: added a minimal reproducible example

import random
from simple_rl.mdp.MDPClass import MDP
from simple_rl.mdp.StateClass import State
from simple_rl.agents import QLearningAgent
from collections import defaultdict
class Toy:
  def __init__(self):
    self.buttons = [random.randint(0, 1) for _ in range(4)]

  def __eq__(self, other):
    return isinstance(other, Toy) and \
      self.buttons == other.buttons

  def __hash__(self):
    return hash(tuple(self.buttons))
  
  def spin(self):
    steps = random.randint(0, 3)
    self.buttons = self.buttons[steps:] + self.buttons[:steps]

  def press(self, to_press):
    for button in to_press:
      self.buttons[button] = 1 - self.buttons[button]
    self.spin()

  def is_solved(self):
    return all(button == 0 for button in self.buttons)
class ToyMDP(MDP):
    ACTIONS = ["1", "2", "3", "4"]

    def __init__(self): 

        # Setup init state.
        self.toy = Toy()
        init_state = State(data=self.toy.buttons)
        
        MDP.__init__(self, ToyMDP.ACTIONS, self._transition_func, self._reward_func, init_state=init_state)

    def get_parameters(self):
        param_dict = defaultdict(int)
        param_dict["num_buttons"] = 4
        return param_dict

    def _reward_func(self, state, action, next_state):
        if state.is_terminal():
            return 0
        return int(next_state.is_terminal())

    def _transition_func(self, state, action):
        # Get buttons to press from action name.
        to_press = list(range(int(action)))

        # Make new state.
        self.toy.press(to_press)
        new_state = State(self.toy.buttons)

        # Set terminal.
        if self._is_goal_state(state):
            new_state.set_terminal(True)

        return new_state

    def _is_goal_state(self, state):
        return sum(state) == 0

    def __str__(self):
        return 'toy'
def train_agent(agent, mdp, num_steps=1000, num_episodes=10):
    for i in range(num_episodes):
        state = mdp.get_init_state()
        reward = 0
        for _ in range(num_steps):
            # Compute the agent's policy.
            action = agent.act(state, reward)

            # Terminal check.
            if state.is_terminal():
                break

            # Execute in MDP.
            reward, next_state = mdp.execute_agent_action(action)

            # Update pointer.
            state = next_state
        agent.end_of_episode()
        mdp.reset()
mdp = ToyMDP()
agent = QLearningAgent(actions=mdp.get_actions())
train_agent(agent, mdp)

Then run the following to find the duplicates:

q = set()
for k in agent.q_func:
    q.add(k) # or q.add(hash(k))
print(len(q), len(agent.q_func))
selroh18
  • you really need to provide a [mcve] – juanpa.arrivillaga Apr 08 '23 at 07:45
  • Are you mutating your objects after adding them to the dictionary? That makes them unfindable, other than by iterating the entire dictionary. – jasonharper Apr 18 '23 at 15:37
  • *Minimal* means the smallest amount of code which demonstrates your issue. Show us two `State` objects which are equal according to you, but not equal according to the `set`/`list` length check. *Minimal* also means not depending on things that are irrelevant to your issue; since your issue is about Python language features (`__eq__`, `__hash__` and dictionaries/sets), you should be able to reproduce the issue without importing a third-party library. – kaya3 Apr 18 '23 at 15:37

1 Answer


The problem: State objects' hashes change during their lifetime

When you create your State objects, their .data is the .buttons of some Toy object. Specifically, the two places where state objects are created are:

init_state = State(data=self.toy.buttons)

and:

new_state = State(self.toy.buttons)

Now, the .buttons of a Toy object is a list object, which is mutable. Crucially, that list is mutated when toy.press() is called.

So after invoking toy.press(), the value of state.data will be different for any state object that was created with a reference to that Toy's .buttons. This means that hash(state) will also be different after invoking toy.press().

As a minimal repro of this, you could do something like this:

t = Toy()
s = State(data=t.buttons)
print("Before .press()")
print(f"t.buttons: {t.buttons}")
print(f"s.data: {s.data}")
print(f"hash(s): {hash(s)}")
t.press((0,))
print("After .press()")
print(f"t.buttons: {t.buttons}")
print(f"s.data: {s.data}")
print(f"hash(s): {hash(s)}")

This will demonstrate that hash(s) changes through the State object's lifetime.

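If an object's hash changes during its lifetime, it can no longer be used reliably as a dictionary key. The dictionary files each entry under the hash the key had when it was inserted, so a later lookup with the new hash misses the old entry and a fresh, apparently duplicate, entry is added. See: What happens if an object's __hash__ changes?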
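The same effect can be reproduced without simple_rl. The sketch below uses a hypothetical Key class as a stand-in for State (it is not the library's implementation, it just keeps a reference to a list and hashes by its contents) and counts distinct keys the same way the question does:

```python
class Key:
    """Hypothetical stand-in for State: compares and hashes by its list contents."""

    def __init__(self, data):
        self.data = data  # stores a reference to the caller's list, not a copy

    def __eq__(self, other):
        return isinstance(other, Key) and self.data == other.data

    def __hash__(self):
        return hash(tuple(self.data))


buttons = [0, 1, 0, 1]
table = {}

k1 = Key(buttons)
table[k1] = "first"   # filed under hash((0, 1, 0, 1))

buttons[0] = 1        # in-place mutation: hash(k1) is now hash((1, 1, 0, 1))

k2 = Key(buttons)
table[k2] = "second"  # the lookup misses k1's stale entry, so a new entry is added

# Count distinct keys the same way the question does.
unique = set()
for k in table:
    unique.add(k)

print(k1 == k2, hash(k1) == hash(k2))  # the two keys look identical
print(len(unique), len(table))         # yet the dict holds more entries than the set
```

Both keys are equal and have equal hashes when you inspect them afterwards, which matches what the question observed, but the dictionary still holds two entries because the first one was stored under the old hash.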

Possible solutions

To fix this, you could do either or both of the following:

  1. Ensure State objects use a copy of the Toy's data

This would involve creating state objects as:

State(data=self.toy.buttons.copy())

or perhaps better, converting to a tuple (guaranteed immutable):

State(data=tuple(self.toy.buttons))

  2. Don't mutate self.buttons in the Toy class

This would involve changing the implementation of press() so that it assigns a new list (or tuple) to self.buttons rather than modifying the existing self.buttons in place.

For example:

def press(self, to_press):
  self.buttons = [(1-b) if i in to_press else b for i, b in enumerate(self.buttons)]
  self.spin()
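As a rough check of both ideas outside simple_rl, here is a small sketch. The standalone press function and the plain tuple key are hypothetical simplifications, not the Toy or State classes themselves:

```python
def press(buttons, to_press):
    """Hypothetical non-mutating press: returns a new list, leaves the input untouched."""
    return [1 - b if i in to_press else b for i, b in enumerate(buttons)]


buttons = [0, 1, 0, 1]

# Option 1: key the table on an immutable snapshot taken at creation time.
key = tuple(buttons)
table = {key: 0.0}
hash_before = hash(key)

# Option 2: build the next button list instead of editing the current one.
next_buttons = press(buttons, {0})

# Even a later in-place edit of the original list cannot reach the snapshot.
buttons[1] = 0

print(next_buttons)                # new list with button 0 flipped
print(hash(key) == hash_before)    # the key's hash is unchanged
print(key in table)                # so the entry is still found
```

Taking the tuple snapshot is the more defensive of the two, since it keeps working even if some other code later mutates the list in place.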
slothrop