I'm getting the following error:
assert q_values.shape == (len(state_batch), self.nb_actions)
AssertionError
q_values.shape <class 'tuple'>: (1, 1, 10)
(len(state_batch), self.nb_actions) <class 'tuple'>: (1, 10)
It is raised by the SARSA agent in the keras-rl library:
rl.agents.sarsa.SARSAAgent#compute_batch_q_values
batch = self.process_state_batch(state_batch)
q_values = self.model.predict_on_batch(batch)
assert q_values.shape == (len(state_batch), self.nb_actions)
Here is my code:
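from typing import List

from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam
from rl.agents.sarsa import SARSAAgent
from rl.core import Env
from rl.policy import BoltzmannQPolicy


class MyEnv(Env):
    def __init__(self):
        self._reset()

    def _reset(self) -> None:
        self.i = 0

    def _get_obs(self) -> List[float]:
        # Fixed dummy observation of 20 values
        return [1] * 20

    def reset(self) -> List[float]:
        self._reset()
        return self._get_obs()
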
model = Sequential()
model.add(Dense(units=20, activation='relu', input_shape=(1, 20)))
model.add(Dense(units=10, activation='softmax'))
model.summary(print_fn=logger.info)  # summary() returns None, so pass the logger as print_fn
policy = BoltzmannQPolicy()
agent = SARSAAgent(model=model, nb_actions=10, policy=policy)
optimizer = Adam(lr=1e-3)
agent.compile(optimizer, metrics=['mae'])
env = MyEnv()
agent.fit(env, 1, verbose=2, visualize=True)
Can someone explain how the dimensions should be set up and how they flow through these libraries? I'm passing in a list of 20 inputs and want an output of 10 values.
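
For reference, the shape bookkeeping behind the assertion can be sketched in plain Python. This is only a sketch, not Keras or keras-rl code: the helper names `dense_shape` and `flatten_shape` are hypothetical, and it assumes that a Keras `Dense` layer transforms only the last axis of its input and that `Flatten` merges every axis except the batch axis. It also assumes the agent wraps each observation in an extra window axis of length 1, which would match the `(1, 1, 10)` shape in the error.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def dense_shape(input_shape: Shape, units: int) -> Shape:
    """Shape after a Dense layer: only the last axis changes (assumption)."""
    return input_shape[:-1] + (units,)


def flatten_shape(input_shape: Shape) -> Shape:
    """Shape after Flatten: keep the batch axis, merge the rest (assumption)."""
    return (input_shape[0], prod(input_shape[1:]))


# One observation of 20 values, wrapped in a window axis of length 1.
batch: Shape = (1, 1, 20)

# Model as posted: Dense(20) then Dense(10), applied directly to the 3-D batch.
as_posted = dense_shape(dense_shape(batch, 20), 10)

# Variant with a Flatten step in front of the two Dense layers.
with_flatten = dense_shape(dense_shape(flatten_shape(batch), 20), 10)

print(as_posted)     # (1, 1, 10)
print(with_flatten)  # (1, 10)
```

If those assumptions hold, adding a `Flatten(input_shape=(1, 20))` layer before the first `Dense` layer (and dropping `input_shape` from that `Dense` layer) should produce the `(1, 10)` output the assertion expects.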