TL;DR: Using PyTorch with Optuna and multiprocessing (a Manager().Queue() handing out GPU indices), one GPU (out of 4) can hang. Probably not a deadlock. Any ideas?
Longer version:
I am using PyTorch in combination with Optuna (a hyperparameter optimization framework that basically runs multiple trials of one model with different parameters, see: https://optuna.readthedocs.io/en/stable/) to train my model on a machine with 4 GPUs. I was looking for a way to distribute the workload across the GPUs more efficiently, so I explored the multiprocessing library.
The core of the multiprocessing code looks like the following:
import multiprocessing
from contextlib import contextmanager

N_GPUS = 4  # number of available GPUs

class GpuQueue:
    def __init__(self):
        # manager-backed queue shared between trials, pre-filled with one index per GPU
        self.queue = multiprocessing.Manager().Queue()
        all_idxs = list(range(N_GPUS)) if N_GPUS > 0 else [None]
        for idx in all_idxs:
            self.queue.put(idx)

    @contextmanager
    def one_gpu_per_process(self):
        # block until a GPU index is free, hand it out, put it back afterwards
        current_idx = self.queue.get()
        yield current_idx
        self.queue.put(current_idx)
class Objective:
    def __init__(self, gpu_queue: GpuQueue, params, signals):
        self.gpu_queue = gpu_queue
        # create dataset
        # ...

    def __call__(self, trial: optuna.Trial):
        # claim a free GPU for the duration of this trial
        with self.gpu_queue.one_gpu_per_process() as gpu_i:
            val = trainer(trial, gpu=gpu_i, ...)
        return val
And in main, the Optuna study is created and optimization is started with:
study = optuna.create_study(direction="minimize", sampler=optuna.samplers.TPESampler(seed=17))  # storage="sqlite:///trials.db"
study.optimize(Objective(GpuQueue(), ...), n_jobs=4)
The same implementation can be found in this Stack Overflow post (which I used as inspiration): Is there a way to pass arguments to multiple jobs in optuna?
With this code every trial gets its own GPU, so GPU usage and distribution are better than with other methods I tried. However, it often happens that a GPU gets stuck, effectively 'shuts itself off', and never finishes its trial, so the run never completes and that GPU is never freed.
Say, for example, that I am running 100 trials. Trials 1, 2, 3, 4 get assigned GPUs 0, 1, 2, 3 (not always in that order), and whenever a GPU is freed, say GPU 2, it takes on trial 5, and so on. The issue is that the trial a GPU is assigned to can 'quit' partway through and never finish, so that GPU never takes on another trial and a run with many trials never completes.
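To figure out where a stuck worker actually is, one thing I am planning to try is dumping Python stack traces on demand with faulthandler (the log file name below is just a placeholder of mine, and this signal-based approach only works on non-Windows platforms):

import faulthandler
import os
import signal

# Near the top of the training script: when a GPU looks stuck, run
# `kill -USR1 <pid>` from a shell and the current stack of every thread
# in that process gets written to this file.
debug_log = open(f"hang_debug_{os.getpid()}.log", "w")
faulthandler.register(signal.SIGUSR1, file=debug_log, all_threads=True)

That should at least show whether the stuck trial is blocked in queue.get(), inside CUDA calls, or somewhere in a DataLoader worker.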
I suspected a deadlock, but apparently Queue() is thread-safe (see: Is Python multiprocessing.Queue thread safe?).
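One thing I still want to rule out is that a crashed trial never puts its GPU index back, so that later calls to queue.get() simply block forever while the GPU itself sits idle. To tell that apart from a genuine hang inside trainer(), I am considering a variant of the context manager like this (the timeout value and error message are placeholders of mine, not part of the code above):

from contextlib import contextmanager
from queue import Empty

@contextmanager
def one_gpu_per_process(self):
    try:
        # fail loudly instead of blocking forever when no GPU index shows up
        current_idx = self.queue.get(timeout=3600)  # placeholder: 1 hour
    except Empty:
        raise RuntimeError("no GPU index was returned to the queue within 1 hour")
    try:
        yield current_idx
    finally:
        # return the index even if trainer() raises, so a failed trial
        # cannot permanently take a GPU out of the pool
        self.queue.put(current_idx)

If the run then fails with the RuntimeError, a GPU index got lost (e.g. a trial died without returning it); if it still hangs silently, the problem is inside the trial itself.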
Any clue as to what can cause the hang and what I can look for?