
I'm writing a function to estimate the maximum batch size that can fit into GPU memory. To test my solution, I've written a helper that generates many random computational graphs from core Keras layers and compiles Keras models. I then pass the models to this function:

import operator as op
from functools import reduce
from itertools import chain
from math import floor, log

import keras.backend as K
from keras.models import Model


def estimate_batch_size(model: Model, available_mem: int,
                        scale_by: float = 5.0,
                        precision: int = 2) -> int:
    """
    :param model: a compiled Keras Model
    :param available_mem: available GPU memory in bytes
    :param scale_by: safety factor that shrinks the estimate
    :param precision: bytes per scalar: 2 for fp16, 4 for fp32, etc.
    :return: the nearest power of two not exceeding the estimated batch size
    """
    # Per-sample footprint: every layer's activations (batch axis dropped)
    # plus all trainable and non-trainable weights.
    num_params = sum(chain.from_iterable((
        (reduce(op.mul, l.output_shape[1:], 1) for l in model.layers),
        (K.count_params(x) for x in model.trainable_weights),
        (K.count_params(x) for x in model.non_trainable_weights)
    )))
    # Divide the memory budget by the scaled per-sample footprint and
    # round down to the nearest power of two.
    max_size = int(available_mem / (precision * num_params * scale_by))
    return int(2 ** floor(log(max_size, 2)))
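
Here is a minimal usage sketch (the toy model and the 8 GB budget are illustrative; my real test harness builds random graphs):

from keras.layers import Dense, Input

# Illustrative toy model, mirroring the summary further below
inp = Input(shape=(50, 45, 1))
model = Model(inp, Dense(34)(inp))

print(estimate_batch_size(model, available_mem=8 * 1024 ** 3,
                          scale_by=4.5, precision=4))  # prints 4096 here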

I've added the scaling factor to get a conservative estimate and to avoid "out of memory" (OOM) errors during automatic hyper-parameter optimisation. Somehow, even with a scaling factor of 4.5, I get OOMs. Here is an example model summary:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_252 (InputLayer)       (None, 50, 45, 1)         0
_________________________________________________________________
dense_696 (Dense)            (None, 50, 45, 34)        68
_________________________________________________________________
dense_699 (Dense)            (None, 50, 45, 279)       9765
=================================================================
Total params: 9,833
Trainable params: 9,833
Non-trainable params: 0
_________________________________________________________________

Here, given 8 GB of VRAM and scale_by=4.5, the function returns a batch size of 1024 (716,333 fp32 parameters, including intermediate placeholders, times 4.5). Yet, despite the large scaling factor, I still get OOMs. I know this method does not account for the placeholders allocated for gradient calculations, among other things, but I'm still bewildered that even a scale factor of 4.5 does not yield a safe estimate. Is it possible to get a more accurate estimate?
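
For reference, the 716,333 figure decomposes into the per-sample activation counts from the summary above plus the weights:

# Per-sample activations for the three layers, plus the weight count
activations = 50 * 45 * 1 + 50 * 45 * 34 + 50 * 45 * 279  # 706,500
weights = 9833
print(activations + weights)  # 716333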

Eli Korvigo
  • It's a great idea to do something like that. Estimating the batch size is still one of the most annoying things when working with deep networks. – ixeption Mar 04 '19 at 09:37
  • I tried the function with scale 1 and precision 4 (my default is `float32`); it gave values less than half of what I could actually train with. With precision 2 it would come closer, but still less than what I could train on my GPU. This is not giving the theoretical limit; maybe you need to adjust how you calculate `max_size`. What's the theory you are using, any references? – Saravanabalagi Ramachandran May 23 '19 at 14:17
  • Perhaps you need to take into account the [size of tensors](https://stackoverflow.com/a/46656508/3125070)? – Saravanabalagi Ramachandran May 23 '19 at 14:24
  • @SaravanabalagiRamachandran Yes, we've already tried that. Moreover, we've tried using the built-in memory profiler in `tf.profile.profile`, and the results do not correlate with any expectations. Real memory consumption depends heavily on the optimisations TF applies at the C++ level (for gradient calculations, FFT, etc.), and there seems to be no safe universal, i.e. model-agnostic, solution. – Eli Korvigo May 23 '19 at 14:57
  • @EliKorvigo Getting an accurate estimate is hard, because the memory usage needs to account for intermediate computed values, not just the parameter weights. As an extreme case, consider f = lambda x: x * x. This has 0 params, but its memory usage would be non-zero, since both x and x * x consume memory. If you further compute gradients, there will be some complicated logic determining which intermediate computations remain live through the whole step, and the set of live tensors would depend on the runtime and possibly be non-deterministic. –  Sep 19 '19 at 23:12
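
A minimal sketch of that zero-parameter case, assuming the standalone Keras API used above (the shapes are arbitrary):

from keras.layers import Input, Lambda
from keras.models import Model

inp = Input(shape=(1024, 1024))
squared = Lambda(lambda x: x * x)(inp)  # 0 params, yet x * x is a live tensor
Model(inp, squared).summary()  # Total params: 0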

0 Answers