I want to catch the runtime error `CUDA out of memory` at multiple places in my code, so that I can rerun the whole training workflow with a smaller batch size. What is the best way to do that?
I am currently doing this:
try:
    result = model(input)
# if the GPU runs out of memory, start the experiment again with a smaller batch size
except RuntimeError as e:
    if str(e).startswith('CUDA out of memory.') and batch_size > 10:
        raise CudaOutOfMemory(e)
    else:
        raise e
I then catch the CudaOutOfMemory error outside my main function.
However, this is a fairly long piece of code that I need to repeat many times. Is there a way to make a context manager for this, so that instead I can run:
with catch_cuda_out_of_mem_error:
    result = model(input)
Edit: I want to create a context manager rather than a function because the functions I want to wrap the try/except around are not always the same. In my workflow, many different functions use a lot of GPU memory, and I would like to catch this error in any of them.
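For reference, here is a minimal sketch of what such a context manager could look like, built with `contextlib.contextmanager` from the standard library. The names `catch_cuda_oom`, `CudaOutOfMemory`, and the `min_batch_size` parameter are illustrative, not from any library:

```python
from contextlib import contextmanager


class CudaOutOfMemory(Exception):
    """Raised so the outer workflow can restart with a smaller batch size."""


@contextmanager
def catch_cuda_oom(batch_size, min_batch_size=10):
    # Illustrative sketch: re-raise PyTorch's OOM RuntimeError as a custom
    # exception, but only while the batch size can still be reduced.
    try:
        yield
    except RuntimeError as e:
        if str(e).startswith("CUDA out of memory.") and batch_size > min_batch_size:
            raise CudaOutOfMemory(e) from e
        raise
```

Usage then reduces to:

```python
with catch_cuda_oom(batch_size):
    result = model(input)
```

As a side note, recent PyTorch versions also expose a dedicated `torch.cuda.OutOfMemoryError` exception class, which may be more robust to catch than matching on the message string.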