I'm training an image segmentation model with tf.keras, using a custom data generator to read and augment images. Training works fine (i.e. without memory problems), but when I try to predict on my test set, my GPU (8 GB, see nvidia-smi output below) runs out of memory. This happens both when predicting directly after training and when restarting the kernel, rebuilding the model, loading the trained weights with model.load_weights() and then calling model.predict(). It also happens both with the batch size used in training (4, which uses ~6 GB of memory during training) and with a batch size of 1; in both cases TensorFlow tries to allocate more than 8 GB.
During training, memory usage is stable at around 6 GB. With model.predict(), it also starts out at around 6 GB, but jumps to 8 GB after approximately 10 seconds and then throws a ResourceExhaustedError (full stack trace below). This seems very counter-intuitive to me, and the tips I've found through Google (e.g. restarting Python, loading the model from its weights and then predicting, so that memory used beforehand is freed) haven't helped, so any help would be great.
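For what it's worth, GPU memory use can also be queried from inside TensorFlow right around the failing call; a quick sketch of how I'd check it (assuming tf.config.experimental.get_memory_usage is available in my TF version):
import tensorflow as tf

# print current GPU memory use (in GB) right before and after the failing predict call
print("before predict:", tf.config.experimental.get_memory_usage("GPU:0") / 1e9)
try:
    preds = model.predict(test_datagen)
finally:
    print("after predict:", tf.config.experimental.get_memory_usage("GPU:0") / 1e9)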
The output of !nvidia-smi, my code for the data generator, and the training/prediction code including the error message are as follows:
nvidia-smi
Mon Aug 9 14:27:29 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 471.11 Driver Version: 471.11 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... WDDM | 00000000:0D:00.0 On | N/A |
| 56% 50C P8 24W / 220W | 8057MiB / 8192MiB | 4% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1576 C+G Insufficient Permissions N/A |
| 0 N/A N/A 2292 C+G ...kyb3d8bbwe\Calculator.exe N/A |
| 0 N/A N/A 8316 C+G C:\Windows\explorer.exe N/A |
| 0 N/A N/A 8736 C+G ...lPanel\SystemSettings.exe N/A |
| 0 N/A N/A 11220 C+G ...bbwe\Microsoft.Photos.exe N/A |
| 0 N/A N/A 11740 C+G ...5n1h2txyewy\SearchApp.exe N/A |
| 0 N/A N/A 12280 C+G ...ekyb3d8bbwe\YourPhone.exe N/A |
| 0 N/A N/A 12820 C+G ...8wekyb3d8bbwe\GameBar.exe N/A |
| 0 N/A N/A 13820 C+G ...perience\NVIDIA Share.exe N/A |
| 0 N/A N/A 14552 C+G ...nputApp\TextInputHost.exe N/A |
| 0 N/A N/A 14848 C+G ...y\ShellExperienceHost.exe N/A |
| 0 N/A N/A 14976 C+G ...zilla Firefox\firefox.exe N/A |
| 0 N/A N/A 15688 C+G ...udibleRT.WindowsPhone.exe N/A |
| 0 N/A N/A 16628 C ...Data\Anaconda3\python.exe N/A |
| 0 N/A N/A 23648 C+G ...aming\Spotify\Spotify.exe N/A |
+-----------------------------------------------------------------------------+
Data Generator
import random

import cv2
import numpy as np
import tensorflow as tf

# build_masks is my own helper that builds the mask array from the RLE columns (not shown)

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, df, batch_size, mode="train", shuffle=True, augment=False, p_augment=0,
                 union=False, greyscale=False, normalize=True, dims=(256, 1600)):
        """DataGenerator usable for train/val/test splits"""
        self.df = df
        self.length = len(df)
        self.BATCH_SIZE = batch_size
        self.mode = mode
        self.shuffle = shuffle
        self.augment = augment
        self.p_augment = p_augment
        self.union = union
        self.greyscale = greyscale
        self.normalize = normalize
        self.dims = dims
        self.num_channels = 1 if greyscale else 3
        self.num_classes = 1 if union else 4
        self.indices = df.index.values.tolist()  # will be reset anyways
        self.on_epoch_end()
        assert mode in ["train", "predict"], "DataGenerator mode is unsupported. Set it to \"train\" or \"predict\"."
        if augment:
            assert p_augment > 0 and p_augment <= 1, "Augmentation is turned on, but probability is zero or larger than one."

    def __len__(self):
        """number of batches in each epoch"""
        return int(np.floor(self.length / self.BATCH_SIZE))

    def on_epoch_end(self):
        """shuffle list of indices"""
        # called at the end of every epoch
        self.indices = self.df.index.values.tolist()
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _load_img(self, img_path):
        """loads image in RGB/greyscale and normalizes it"""
        if self.greyscale:
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        else:
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.normalize:
            img = img.astype(np.float32) / 255.
        else:
            img = img.astype(np.float32)
        return img

    def _gen_x(self, idx_list):
        """generates input values from list of indices"""
        res = np.empty(shape=(self.BATCH_SIZE, *self.dims, self.num_channels))
        for i, df_idx in enumerate(idx_list):
            img_path = self.df.loc[df_idx]["img_id"]
            img = self._load_img(img_path)
            if self.greyscale:
                res[i, ] = np.expand_dims(img, axis=-1)
            else:
                res[i, ] = img
        return res

    def _gen_tgt(self, idx_list):
        """generates target values from list of indices"""
        res = np.empty(shape=(self.BATCH_SIZE, *self.dims, self.num_classes))
        for i, df_idx in enumerate(idx_list):
            rles = self.df.loc[df_idx]["c1":"c_all"]
            if self.union:
                # return mask of all defect pixels (no diff between defect classes)
                masks = build_masks(rles, union_only=True)
            else:
                masks = build_masks(rles)
            res[i, ] = masks
        return res

    def __getitem__(self, idx):
        """creates one batch of data"""
        # get indices of batch (self.indices is shuffled list of df indices)
        idxs = self.indices[idx*self.BATCH_SIZE:(idx+1)*self.BATCH_SIZE]
        x = self._gen_x(idxs)
        if self.mode == "predict":
            return x
        # mode is train -> get target data and possibly augment
        tgt = self._gen_tgt(idxs)
        if self.augment:
            x, tgt = self._augment_batch(x, tgt)
        return x, tgt

    def _augment_batch(self, _x, _tgt):
        # flips img and masks vertically and/or horizontally, each with probability p_augment
        for i in range(self.BATCH_SIZE):
            # flip up-down
            if random.random() < self.p_augment:
                if self.greyscale:
                    _x[i] = np.expand_dims(cv2.flip(_x[i], flipCode=0), axis=-1)
                else:
                    _x[i] = cv2.flip(_x[i], flipCode=0)
                _tgt[i] = cv2.flip(_tgt[i], flipCode=0)
            # flip left-right
            if random.random() < self.p_augment:
                if self.greyscale:
                    _x[i] = np.expand_dims(cv2.flip(_x[i], flipCode=1), axis=-1)
                else:
                    _x[i] = cv2.flip(_x[i], flipCode=1)
                _tgt[i] = cv2.flip(_tgt[i], flipCode=1)
        return _x, _tgt
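For reference, a single batch from this generator in predict mode has the expected shape; a quick sanity-check sketch (hypothetical call, assuming df_test holds valid image paths in its img_id column):
# shape-check one predict-mode batch (hypothetical, not part of the training script)
gen = DataGenerator(df_test, batch_size=1, mode="predict", shuffle=False,
                    greyscale=False, normalize=True, dims=(256, 1600))
x = gen[0]
print(x.shape, x.dtype)  # (1, 256, 1600, 3) float64 -- np.empty defaults to float64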
Training
from copy import deepcopy

from tensorflow.keras.optimizers import Adam

# get_model_from_generator, bce_dice_loss and dice_coef are my own helpers (not shown)

# configs for train/val datagens
train_config = {"mode": "train",
                "batch_size": 4,
                "shuffle": True,
                "augment": True,
                "p_augment": 0.5,
                "union": False,
                "greyscale": False,
                "normalize": True,
                "dims": (256, 1600)}
val_config = deepcopy(train_config)
val_config["shuffle"] = False
val_config["augment"] = False

train_datagen = DataGenerator(df_train, **train_config)
val_datagen = DataGenerator(df_val, **val_config)

# returns model with correct image dims and number of classes
model = get_model_from_generator(train_datagen)
model.compile(optimizer=Adam(learning_rate=1e-4), loss=bce_dice_loss,
              metrics=["binary_crossentropy", dice_coef])

cb_es = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
cb_best = tf.keras.callbacks.ModelCheckpoint("models/fcn_rgb/cp_{epoch:02d}_{val_loss:.3f}.ckpt", monitor="val_loss",
                                             save_weights_only=True, save_best_only=True)

history = model.fit(x=train_datagen, callbacks=[cb_es, cb_best], epochs=100,
                    validation_data=val_datagen)
Training runs fine (see the output of the first two epochs):
Epoch 1/100
2513/2513 [==============================] - 541s 207ms/step - loss: 0.2413 - binary_crossentropy: 0.5235 - dice_coef: 0.0205 - val_loss: -0.0034 - val_binary_crossentropy: 0.1481 - val_dice_coef: 0.0775
Epoch 2/100
2513/2513 [==============================] - 518s 206ms/step - loss: -0.1231 - binary_crossentropy: 0.0864 - dice_coef: 0.1663 - val_loss: -0.2862 - val_binary_crossentropy: 0.0627 - val_dice_coef: 0.3175
Prediction
Throws the same error with or without restarting the kernel after training the model.
test_config = {"mode": "predict",
"batch_size": 1,
"shuffle":False,
"augment":False,
"p_augment": 0,
"union": False,
"greyscale": False,
"normalize": True,
"dims": (256,1600)}
test_datagen = DataGenerator(df_test, **test_config)
model = get_model_from_generator(train_datagen)
model.load_weights("models/fcn_rgb/cp_38_-0.497.ckpt")
preds = model.predict(test_datagen)
error message:
---------------------------------------------------------------------------
ResourceExhaustedError Traceback (most recent call last)
<ipython-input-16-d0a77a4d2cd0> in <module>
----> 1 preds = model.predict(test_datagen)
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\engine\training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
1627 for step in data_handler.steps():
1628 callbacks.on_predict_batch_begin(step)
-> 1629 tmp_batch_outputs = self.predict_function(iterator)
1630 if data_handler.should_sync:
1631 context.async_wait()
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
860 # In this case we have not created variables on the first call. So we can
861 # run the first trace but we should fail if variables are created.
--> 862 results = self._stateful_fn(*args, **kwds)
863 if self._created_variables:
864 raise ValueError("Creating variables on a non-first call to a function"
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2940 (graph_function,
2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 2942 return graph_function._call_flat(
2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2944
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1916 and executing_eagerly):
1917 # No tape is watching; skip to running the function.
-> 1918 return self._build_call_outputs(self._inference_function.call(
1919 ctx, args, cancellation_manager=cancellation_manager))
1920 forward_backward = self._select_forward_and_backward_functions(
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
553 with _InterpolateFunctionError(self):
554 if cancellation_manager is None:
--> 555 outputs = execute.execute(
556 str(self.signature.name),
557 num_outputs=self._num_outputs,
~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
ResourceExhaustedError: OOM when allocating tensor with shape[1,96,128,800] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[node model_1/batch_normalization_57/FusedBatchNormV3 (defined at <ipython-input-16-d0a77a4d2cd0>:1) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
[Op:__inference_predict_function_8858]
Function call stack:
predict_function
I'm using tensorflow version 2.4.1.
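One fallback I've considered is predicting one generator batch at a time with predict_on_batch (untested sketch below), but I'd still like to understand why model.predict itself runs out of memory:
import numpy as np

# untested fallback: predict one generator batch at a time to keep peak memory down
preds = []
for i in range(len(test_datagen)):
    x = test_datagen[i]  # mode="predict" -> __getitem__ returns inputs only
    preds.append(model.predict_on_batch(x))
preds = np.concatenate(preds, axis=0)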
Edit: I forgot to mention that I've also tried activating TF's dynamic memory allocation (memory growth) before both training and predicting with the following code, but the error still appears.
# dynamic memory allocation
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # must be set before GPUs have been initialized
        print(e)
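I haven't yet tried hard-capping the GPU memory instead of using memory growth; as I understand it, that would look roughly like this (untested sketch, limiting TF to ~6 GB on the first GPU):
import tensorflow as tf

# untested: cap TensorFlow at ~6 GB on the first GPU instead of growing on demand
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=6144)])
    except RuntimeError as e:
        # must be configured before the GPU has been initialized
        print(e)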