
I'm working on a reinforcement learning project where the agent is a load balancer that observes service requests and the servers' status. The agent is supposed to do a batch training step after accumulating some observations, actions (allocating a request to a server), and rewards (whether the request was handled well, i.e. timely and correctly).

I don't want the agent to be down while the batch training is running.

So I tried to use Python's multiprocessing to start another process for training, passing the current model's weights to it and replacing the agent's model once training finishes.

The problem is that the training process hangs on a predict() call on that copied model (a new instance with the same weights as the agent's model). I call predict() on the new model as part of the reinforcement learning update; fit() is the main purpose of the training process, but either way it gets stuck on predict().

I did some googling and found posts saying that Keras shows some weird behavior when used with multiprocessing, but I couldn't find a solution that works for me.

I also tried more general search queries, like the title of my question. Surprisingly, there's even less information about that.

Below is the code, with some abstraction:

import threading
import time
import numpy as np
from multiprocessing import Process, Queue

def batch_train(self, weights, ret):
    # build a fresh model in the training process and copy in the agent's weights
    model = self.build_model()
    model.set_weights(weights)
    batch = self.memory[:self.batch_size]
    with threading.Lock():
        self.memory = self.memory[self.batch_size:]

    X = []
    Y = []
    for state, action, next_state, reward in batch:
        print(state, action, next_state, reward)
        print(model)
        print(model.predict(next_state)[0])  # <-- the child process hangs inside this predict()
        print('2')
        reward = reward + self.discount_factor * np.amax([0])

        target = model.predict(state)[0]
        print('3')
        target[action] = reward
        X.append(state)
        Y.append(target)
    X = np.array(X)
    Y = np.array(Y)
    model.fit(X, Y, epochs=1)  # , verbose=0)
    print('finish training...')
    ret.put(model)  # hand the trained model back to the main process

def run_and_train(self):
    model = self.build_model()
    print(model)
    p_train = None
    ret = Queue()
    # prev_state and prev_action are assumed to be initialized before this loop
    for t in range(self.batch_size):
        time.sleep(self.observation_interval)
        state = get_state()
        action = get_action()
        reward = get_reward()
        with threading.Lock():
            self.memory.append((prev_state, prev_action, state, reward))
        prev_state = state
        prev_action = action

    # start a process to batch-train a new model and
    # replace the current model with the new one upon completion
    if p_train is not None:
        model = ret.get()
        print('waiting for join')
        p_train.join()
    print('calling batch_train...')
    p_train = Process(target=self.batch_train, args=(model.get_weights(), ret))
    p_train.start()
  • That threading.Lock() also seems to have some problem, but anyway, currently it's not the culprit – jiyolla May 29 '21 at 15:47
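
A note on that lock: with threading.Lock(): constructs a brand-new Lock object on every call, so it never actually excludes anything. If self.memory really needs guarding, the lock has to be a single shared instance; below is a minimal sketch, where memory_lock is a made-up attribute name.

import threading

class Agent:
    def __init__(self):
        self.memory = []
        # one Lock instance, shared by everything that touches self.memory
        self.memory_lock = threading.Lock()

    def remember(self, transition):
        with self.memory_lock:          # same lock object every time
            self.memory.append(transition)

    def take_batch(self, batch_size):
        with self.memory_lock:
            batch = self.memory[:batch_size]
            self.memory = self.memory[batch_size:]
        return batch

Note that a lock like this only coordinates threads inside one process; a separate Process gets its own copy of self.memory, so slicing it in the training process does not affect the agent's copy.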

1 Answer


I don't know why, but importing keras separately in each process solves the issue.

It's mentioned here: How can take advantage of multiprocessing and multithreading in Deep learning using Keras?
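
Concretely, the change is to do the keras import inside the function that the child process runs, instead of at the top of the module. A minimal sketch of the pattern (a toy model and tf.keras are assumed here; this is not my exact code):

from multiprocessing import Process, Queue

def worker(ret):
    # keras is imported here, inside the child process, not at module level
    from tensorflow import keras
    import numpy as np

    model = keras.Sequential([keras.layers.Dense(1, input_shape=(3,))])
    # predict() completes instead of hanging
    ret.put(model.predict(np.zeros((1, 3), dtype='float32'), verbose=0))

if __name__ == '__main__':
    ret = Queue()
    p = Process(target=worker, args=(ret,))
    p.start()
    print(ret.get())
    p.join()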

So I just restructured my code to run the train process and the serve process in parallel from the beginning, communicating over a Pipe().
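
Very roughly, the new structure looks like the sketch below. It is a simplified illustration rather than my real code: the model is a toy, the request handling and Q-target construction are replaced by placeholders, and the loop lengths are arbitrary. The point is that both processes exist from the start, each does its own keras import, and they exchange batches and weights over the two ends of a Pipe():

from multiprocessing import Process, Pipe

def serve(conn):
    # serving/agent process: owns the live model and keeps handling requests
    from tensorflow import keras
    import numpy as np

    model = keras.Sequential([keras.layers.Dense(2, input_shape=(4,))])
    memory = []
    for step in range(30):                                      # stand-in for the real request loop
        state = np.random.rand(1, 4).astype('float32')          # stand-in observation
        action = int(np.argmax(model.predict(state, verbose=0)[0]))  # act with the current model
        reward = 1.0                                            # stand-in reward
        memory.append((state, action, reward))

        if len(memory) >= 10:                                   # ship a batch to the trainer, keep serving
            conn.send((model.get_weights(), memory))
            memory = []
        if conn.poll():                                         # non-blocking: new weights ready?
            model.set_weights(conn.recv())
    conn.send(None)                                             # tell the trainer we're done

def train(conn):
    # training process: does its own keras import as well
    from tensorflow import keras
    import numpy as np

    model = keras.Sequential([keras.layers.Dense(2, input_shape=(4,))])
    while True:
        msg = conn.recv()
        if msg is None:
            break
        weights, batch = msg
        model.set_weights(weights)
        X = np.concatenate([s for s, _, _ in batch])
        Y = model.predict(X, verbose=0)                         # placeholder targets; real code builds Q-targets
        model.fit(X, Y, epochs=1, verbose=0)
        conn.send(model.get_weights())                          # hand trained weights back

if __name__ == '__main__':
    serve_conn, train_conn = Pipe()
    p_serve = Process(target=serve, args=(serve_conn,))
    p_train = Process(target=train, args=(train_conn,))
    p_serve.start()
    p_train.start()
    p_serve.join()
    p_train.join()

Since Pipe() connections just send()/recv() picklable Python objects, passing the weight lists back and forth is all the coordination this needs, and the serving loop never blocks on training because it only poll()s for new weights.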
