I'm implementing a version of AlphaZero (AlphaGo's most recent incarnation) to be applied to some other domain.
The crux of the algorithm is a Monte Carlo Tree Search of the state space (CPU) interleaved with 'intuition' (probabilities) from a neural network in eval mode (GPU). The MCTS result is then used to train the neural network.
I already parallelized the CPU execution by launching multiple processes which each build up their own tree. This is effective and has now led to a GPU bottleneck! (nvidia-smi shows the GPU at 100% utilization all the time.)
I have devised two strategies to parallelize the GPU evaluations, but both of them have problems.
The first strategy is to have each process evaluate the network only on batches from its own tree. In my initial naive implementation this meant a batch size of 1. By refactoring some code and adding a 'virtual loss' to discourage (but not completely block) the same node from being picked twice, I can get batches of size 1-4. The problem here is that I cannot let the selected leaves wait long before evaluating the batch or search accuracy suffers, so the batch size has to stay small.
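To make this concrete, here is a simplified sketch of what I mean by per-tree batching with virtual loss (puct_score, state_tensor, expand and value_sum stand in for my actual node implementation; the constants are arbitrary):

```python
import torch

VIRTUAL_LOSS = 3  # extra visits applied along a path to discourage re-selecting it

def collect_leaves(root, batch_size=4):
    """Descend the tree batch_size times; the virtual loss added along each path
    makes consecutive descents tend to end at different leaves."""
    leaves, paths = [], []
    for _ in range(batch_size):
        node, path = root, []
        while node.children:                        # standard PUCT descent
            node = max(node.children, key=lambda c: c.puct_score())
            node.visit_count += VIRTUAL_LOSS        # virtual loss: node now looks over-visited
            path.append(node)
        leaves.append(node)
        paths.append(path)
    return leaves, paths

def evaluate_and_backup(net, leaves, paths):
    """One GPU call for all collected leaves, then back up the real results
    and remove the virtual loss."""
    states = torch.stack([leaf.state_tensor() for leaf in leaves]).cuda()
    with torch.no_grad():
        priors, values = net(states)
    for leaf, path, p, v in zip(leaves, paths, priors, values):
        leaf.expand(p.cpu())                        # create children from the priors
        for node in path:
            node.visit_count += 1 - VIRTUAL_LOSS    # net effect: exactly one real visit
            node.value_sum += v.item()
```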
The second strategy is to send the batches to a central "neural network worker" thread which combines and evaluates them. This could be done in a large batch of 32 or more, so the GPU could be used very efficiently. The problem here is that the tree workers have to send CUDA tensors 'round-trip', which PyTorch does not support. It does work if I clone them first, but all that constant copying makes this approach slower than the first one.
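For reference, this is roughly the pattern I mean, stripped down to its core: the NN worker runs as a separate process, the tree workers ship CPU tensors through torch.multiprocessing queues, and the extra host/device copies in the middle are exactly the overhead I described (all names here are made up):

```python
import queue
import torch
import torch.multiprocessing as mp

def nn_worker(net, request_q, response_qs, max_batch=32):
    """Central evaluator: drain whatever requests are queued (up to max_batch),
    run one forward pass, and route each result back to the worker that sent it."""
    net = net.cuda().eval()
    while True:
        wid, state = request_q.get()                          # block for the first request
        ids, states = [wid], [state]
        while len(states) < max_batch:                        # opportunistically grab more
            try:
                wid, state = request_q.get_nowait()
            except queue.Empty:
                break
            ids.append(wid)
            states.append(state)
        with torch.no_grad():
            priors, values = net(torch.stack(states).cuda())  # host -> device copy
        priors, values = priors.cpu(), values.cpu()           # device -> host copy
        for i, wid in enumerate(ids):
            response_qs[wid].put((priors[i], values[i]))

def remote_evaluate(state, worker_id, request_q, response_q):
    """Called from a tree worker: ship the state, block until the answer comes back."""
    request_q.put((worker_id, state.cpu()))
    return response_q.get()

# Setup in the main process would look something like:
# request_q = mp.Queue()
# response_qs = [mp.Queue() for _ in range(num_tree_workers)]
# mp.Process(target=nn_worker, args=(net, request_q, response_qs)).start()
```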
I was thinking that maybe a clever batching scheme that I'm not seeing could make the first approach work. Using multiple GPUs could also speed up the first approach, but the kind of parallelism I want is not natively supported by PyTorch. Keeping all tensors in the NN worker and only sending ids around could improve the second approach; the difficulty there is how to synchronize effectively so that I get a large batch without making the tree workers wait too long.
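Roughly, the kind of synchronization I have in mind on the NN worker's side is a "fill up or time out" loop like the one below (the 2 ms timeout and the names are arbitrary). The timeout would bound how long any single tree worker can stall, but I have not verified that this actually keeps the GPU busy:

```python
import queue
import time

def gather_batch(request_q, max_batch=32, timeout_s=0.002):
    """Block for the first request, then keep accepting requests until the batch
    is full or timeout_s has elapsed, so no tree worker waits much longer than
    timeout_s for its evaluation to start."""
    requests = [request_q.get()]                 # always wait for at least one request
    deadline = time.monotonic() + timeout_s
    while len(requests) < max_batch:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            requests.append(request_q.get(timeout=remaining))
        except queue.Empty:
            break
    return requests
```

With enough tree processes the batch should usually fill before the timeout, and the ids in each request could index into tensors that never leave the NN worker.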
I found next to no information on how AlphaZero or AlphaGo Zero were parallelized in their respective papers. I did find some limited information online, however, which led me to the improvements to the first approach described above.
I would be grateful for any advice on this, particularly if there's some point or approach I missed.