I'm working on a reinforcement learning project where the agent is a load balancer that observes service requests and servers' status. The agent is supposed to do a batch training step after accumulating some observations, actions (allocating a request to a server), and rewards (whether the request was handled well, i.e. timely and correctly).
I don't want the agent to go down while the batch training runs.
So I'm trying to use Python's multiprocessing to start another process that trains on a copy of the current model's weights and replaces the agent's model once training completes.
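Stripped of the model itself, the hand-off pattern I'm describing looks like the sketch below: pass the weights in, get the result back through a queue (plain multiprocessing, no Keras; `train_step` and the list-of-floats weight format are just stand-ins for the real model).

```python
# Minimal sketch of the weight hand-off pattern: the parent ships weights to a
# child process, the child "trains", and the result comes back via a Queue.
from multiprocessing import Process, Queue

def train_step(weights, ret):
    # Pretend "training": nudge every weight, then hand the result back.
    new_weights = [w + 0.1 for w in weights]
    ret.put(new_weights)

if __name__ == '__main__':
    weights = [1.0, 2.0, 3.0]
    ret = Queue()
    p = Process(target=train_step, args=(weights, ret))
    p.start()
    new_weights = ret.get()   # drain the queue before join() to avoid a deadlock
    p.join()
    print(new_weights)        # the parent would swap these into its live model
```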
The problem is that the training process hangs on a predict() call on the copied model (a new instance with the same weights as the agent's model). I'm calling predict() on the new model for the reinforcement learning update; fit() is the main purpose of the training process, but either way it gets stuck on predict().
I did some googling and found posts saying Keras shows some weird behavior when used with multiprocessing, but I couldn't find a solution that works for me.
I also tried more general search queries, like the title of my question. Surprisingly, there's even less information about that.
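For context, the workaround those posts most often suggest is to use the 'spawn' start method so the child begins with a fresh interpreter instead of inheriting the parent's framework state across a fork. The sketch below shows the mechanics with plain multiprocessing; whether it actually fixes a given Keras setup is an assumption to verify, not something I've confirmed.

```python
# Using a 'spawn' context: the child starts a fresh interpreter, so any
# framework import / session initialization happens cleanly inside the child.
import multiprocessing as mp

def worker(ret):
    # With 'spawn', a Keras model would be imported and built here, in the child.
    ret.put('trained in a clean process')

if __name__ == '__main__':
    ctx = mp.get_context('spawn')   # per-context alternative to mp.set_start_method('spawn')
    ret = ctx.Queue()
    p = ctx.Process(target=worker, args=(ret,))
    p.start()
    print(ret.get())
    p.join()
```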
Below is the code, with some parts abstracted away:
def batch_train(self, weights, ret):
    model = self.build_model()
    model.set_weights(weights)
    batch = self.memory[:self.batch_size]
    with self.lock:  # shared lock; a fresh threading.Lock() here would guard nothing
        self.memory = self.memory[self.batch_size:]
    X = []
    Y = []
    for state, action, next_state, reward in batch:
        print(state, action, next_state, reward)
        print(model)
        print(model.predict(next_state)[0])  # hangs here
        print('2')
        reward = reward + self.discount_factor * np.amax([0])  # placeholder while debugging the hang
        target = model.predict(state)[0]
        print('3')
        target[action] = reward
        X.append(state)
        Y.append(target)
    X = np.array(X)
    Y = np.array(Y)
    model.fit(X, Y, epochs=1)  # , verbose=0)
    print('finish training...')
    ret.put(model)
def run_and_train(self):
    model = self.build_model()
    print(model)
    p_train = None
    ret = Queue()
    prev_state = prev_action = None  # nothing to pair with on the first iteration
    for t in range(self.batch_size):
        time.sleep(self.observation_interval)
        state = get_state()
        action = get_action()
        reward = get_reward()
        with self.lock:  # shared lock; a fresh threading.Lock() here would guard nothing
            self.memory.append((prev_state, prev_action, state, reward))
        prev_state = state
        prev_action = action
    # start a process to batch-train a new model,
    # replacing the current model with the new one upon completion
    if p_train is not None:
        model = ret.get()  # drain the queue before join() to avoid a deadlock
        print('waiting for join')
        p_train.join()
    print('calling batch_train...')
    p_train = Process(target=self.batch_train, args=(model.get_weights(), ret))
    p_train.start()