I'm writing a function to estimate the maximum batch size that can fit into a GPU. To test my solution, I've written a function that generates many random computational graphs from core Keras layers and compiles them into Keras models. I then pass the models into this function:
from itertools import chain
from math import log, floor
import keras.backend as K
import operator as op
from functools import reduce
from keras.models import Model
def estimate_batch_size(model: Model, available_mem: int,
                        scale_by: float = 5.0,
                        precision: int = 2) -> int:
    """
    :param model: keras Model
    :param available_mem: available memory in bytes
    :param scale_by: scaling factor
    :param precision: float precision in bytes: 2 for fp16, 4 for fp32, etc.
    :return: closest 2^n to the estimated batch size
    """
    # count per-sample activation elements (batch dimension excluded) plus all weights
    num_params = sum(chain.from_iterable((
        (reduce(op.mul, l.output_shape[1:]) for l in model.layers),
        (K.count_params(x) for x in model.trainable_weights),
        (K.count_params(x) for x in model.non_trainable_weights)
    )))
    max_size = int(available_mem / (precision * num_params * scale_by))
    # round down to the nearest power of two
    return int(2 ** floor(log(max_size, 2)))
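
For example, this is roughly how I call it in the test harness (a minimal sketch: the toy model below is hand-built to mirror the summary further down, and the 8 GiB figure is just an example value):

from keras.layers import Input, Dense

# small hand-built model mirroring the example summary below
inp = Input(shape=(50, 45, 1))
hidden = Dense(34, activation="relu")(inp)
out = Dense(279, activation="relu")(hidden)
toy_model = Model(inp, out)

# 8 GiB of VRAM, the conservative scaling factor, fp32 tensors
print(estimate_batch_size(toy_model, available_mem=8 * 1024 ** 3,
                          scale_by=4.5, precision=4))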
I've added the scaling factor to get a conservative estimate and avoid "out of memory" (OOM) errors during automatic hyper-parameter optimisation. Somehow, even with a scaling factor of 4.5, I get OOMs. Here is an example model summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_252 (InputLayer)       (None, 50, 45, 1)         0
_________________________________________________________________
dense_696 (Dense)            (None, 50, 45, 34)        68
_________________________________________________________________
dense_699 (Dense)            (None, 50, 45, 279)       9765
=================================================================
Total params: 9,833
Trainable params: 9,833
Non-trainable params: 0
_________________________________________________________________
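For this model, the element count that the function works with breaks down as the per-layer output sizes (batch dimension excluded) plus the weights:

    input_252:  50 * 45 * 1   =   2,250
    dense_696:  50 * 45 * 34  =  76,500
    dense_699:  50 * 45 * 279 = 627,750
    weights:                      9,833
    total:                      716,333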
Given 8 GB of VRAM and scale_by=4.5, the function returns a batch size of 1024 (716,333 fp32 parameters, including the intermediate placeholders, times 4.5). Yet, despite the large scaling factor, I still get OOMs. I know this method does not account for the placeholders allocated for gradient calculations, among other things, but I'm still bewildered that even a scale factor of 4.5 does not yield a safe estimate. Is it possible to get a more accurate estimate?
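
In case it clarifies what I'm asking for: the kind of correction I imagine would look roughly like the sketch below, but I have no idea whether the extra terms (a second copy of the weights for gradients plus a couple of optimizer slots, e.g. Adam's moments) are anywhere near the real allocation pattern.

def estimate_batch_size_v2(model: Model, available_mem: int,
                           scale_by: float = 2.0,
                           precision: int = 4,
                           optimizer_slots: int = 2) -> int:
    # per-sample activation elements (batch dimension excluded)
    activations = sum(reduce(op.mul, l.output_shape[1:]) for l in model.layers)
    # weight elements, counted once
    weights = sum(K.count_params(w) for w in model.trainable_weights) + \
              sum(K.count_params(w) for w in model.non_trainable_weights)
    # assumption: weights are held ~(2 + optimizer_slots) times
    # (the weights themselves, their gradients, and optimizer state)
    fixed_bytes = precision * weights * (2 + optimizer_slots)
    # assumption: backward-pass activations are folded into scale_by
    per_sample_bytes = precision * activations * scale_by
    max_size = (available_mem - fixed_bytes) // per_sample_bytes
    return int(2 ** floor(log(max_size, 2))) if max_size >= 1 else 1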