
I want to catch the runtime error "CUDA out of memory" in multiple places in my code, so that I can then rerun the whole training workflow with a lower batch size. What is the best way to do that?

I am currently doing this:

try:
    result = model(input)
# if the GPU runs out of memory, start the experiment again with a smaller batch size
except RuntimeError as e:
    if str(e).startswith('CUDA out of memory.') and batch_size > 10:
        raise CudaOutOfMemory(e)
    else:
        raise e

I then catch the error CudaOutOfMemory outside my main function.

However, this is a pretty long piece of code that I need to repeat many times. Is there any way to make a context manager for this?

such that instead I can run:

with catch_cuda_out_of_mem_error():
    result = model(input)

Edit: I want to create a context manager instead of a function because the functions I want to wrap the "try, except" around are not always the same. In my workflow, I have many functions that use a lot of GPU memory and I would like to catch this error in any of them.

Amir Afianian
Jimmy2027

2 Answers


Using a context manager is about properly acquiring and releasing a resource. Here you don't really have any resource that you are acquiring and releasing, so I don't think a context manager is appropriate. How about just using a function?

def try_compute_model(input):
    try:
        return model(input)
    # if the GPU runs out of memory, start the experiment again with a smaller batch size
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.') and batch_size > 10:
            raise CudaOutOfMemory(e)
        raise  # re-raise unrelated RuntimeErrors with their original traceback

Then use it like

result = try_compute_model(input)
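If the concern is that many different functions need this handling, the same function-based idea generalizes to a decorator that you apply once per function instead of repeating the try/except. A sketch, with `CudaOutOfMemory` assumed to be defined as in the question and `compute` as a hypothetical stand-in for a memory-hungry function:

```python
import functools


class CudaOutOfMemory(Exception):
    pass


def retryable_on_oom(func):
    """Translate PyTorch's 'CUDA out of memory' RuntimeError into CudaOutOfMemory."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as e:
            if str(e).startswith('CUDA out of memory.'):
                raise CudaOutOfMemory(e)
            raise  # unrelated RuntimeErrors propagate unchanged
    return wrapper


@retryable_on_oom
def compute(x):
    # stand-in for model(input); raises the same message a real OOM would
    raise RuntimeError('CUDA out of memory. Tried to allocate ...')
```

Any function decorated this way can then be caught via `except CudaOutOfMemory` at the top level.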
mCoding
  • The Python developers are not so strict. E.g. the `contextlib.suppress` manager does not acquire/release resources. – VPfB Nov 18 '20 at 20:48
  • Hi @mCoding thanks for your response! I want to create a contextmanager instead of a function because the function I want to wrap the "try, except" around is not always the same. In my workflow I have many functions that use a lot of GPU memory and I would like to catch this error in any of them. Sorry if my question was misleading in that sense – Jimmy2027 Nov 18 '20 at 21:08

Inspired by the post "General decorator to wrap try except in python?", I found an answer to my problem:

import torch
from contextlib import contextmanager


class CudaOutOfMemory(Exception):
    pass


@contextmanager
def catching_cuda_out_of_memory():
    """
    Context that throws CudaOutOfMemory error if GPU is out of memory.
    """
    try:
        yield
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.'):
            raise CudaOutOfMemory(e)
        raise  # re-raise unrelated RuntimeErrors with their original traceback


def oom():
    # Deliberately exhaust GPU memory: allocate on device 1 (assumes a second
    # GPU is present) and keep stacking large linear layers.
    x = torch.randn(100, 10000, device=1)
    for _ in range(100):
        layer = torch.nn.Linear(10000, 10000)
        layer.to(1)
        x = layer(x)


try:
    with catching_cuda_out_of_memory():
        oom()
except CudaOutOfMemory:
    print('GOTCHA!')
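To tie this back to the original goal of rerunning the whole workflow with a smaller batch size, the context manager slots into a retry loop. A sketch without a GPU, where `train` is a hypothetical stand-in for the real training step that pretends any batch size above 16 runs out of memory:

```python
from contextlib import contextmanager


class CudaOutOfMemory(Exception):
    pass


@contextmanager
def catching_cuda_out_of_memory():
    try:
        yield
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.'):
            raise CudaOutOfMemory(e)
        raise


def train(batch_size):
    # hypothetical training step: simulate OOM for large batches
    if batch_size > 16:
        raise RuntimeError('CUDA out of memory. Tried to allocate ...')
    return 'ok'


batch_size = 64
while True:
    try:
        with catching_cuda_out_of_memory():
            result = train(batch_size)
        break
    except CudaOutOfMemory:
        batch_size //= 2  # retry the whole workflow with a smaller batch

print(result, batch_size)  # prints: ok 16
```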
Jimmy2027