I'm training a U-Net model using TensorFlow. If there is a problem with any of the images I pass to the model for training, an exception is thrown, sometimes an hour or two into training. Is it possible to catch such exceptions so that the pipeline skips the offending image and training resumes? I've tried adding a try/except block to the process_path function shown below (a sketch of the attempt follows the function), but it has no effect...
def process_path(filePath):
    # catching exceptions here has no effect
    parts = tf.strings.split(filePath, '/')
    fileName = parts[-1]
    parts = tf.strings.split(fileName, '.')
    # build the corresponding mask path: <maskDir>/<name>-mask.png
    prefix = tf.convert_to_tensor(maskDir, dtype=tf.string)
    suffix = tf.convert_to_tensor("-mask.png", dtype=tf.string)
    maskFileName = tf.strings.join((parts[-2], suffix))
    maskPath = tf.strings.join((prefix, maskFileName), separator='/')
    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.io.read_file(maskPath)
    oneHot = decodeMask(mask)
    img.set_shape([256, 256, 3])
    oneHot.set_shape([256, 256, 10])
    return img, oneHot
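This is roughly what I tried, as a minimal sketch (buildMaskPath is a hypothetical helper standing in for the tf.strings path logic above; decode_img and decodeMask are my own decoding functions):
def process_path(filePath):
    try:
        maskPath = buildMaskPath(filePath)  # hypothetical helper, see above
        img = decode_img(tf.io.read_file(filePath))
        oneHot = decodeMask(tf.io.read_file(maskPath))
        img.set_shape([256, 256, 3])
        oneHot.set_shape([256, 256, 10])
        return img, oneHot
    except Exception:
        # never reached: this Python body only runs once, while the graph
        # is being traced; the decode error is raised later, inside the
        # traced graph, where a Python try/except cannot see it
        return tf.zeros([256, 256, 3]), tf.zeros([256, 256, 10])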
trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)
batchSize = 32
allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))
trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.shuffle(1000).repeat()
trainDataSet = trainDataSet.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
trainDataSet = trainDataSet.batch(batchSize)
trainDataSet = trainDataSet.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.shuffle(1000).repeat()
validDataSet = validDataSet.map(process_path)
validDataSet = validDataSet.batch(batchSize)
imageHeight = 256
imageWidth = 256
channels = 3
inputImage = Input((imageHeight, imageWidth, channels), name='img')
model = baseUnet.get_unet(inputImage, n_filters=16, dropout=0.05, batchnorm=True)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1),
    ModelCheckpoint(outputModel, verbose=1, save_best_only=True, save_weights_only=False)
]
BATCH_SIZE = 32
BUFFER_SIZE = 1000
EPOCHS = 20
stepsPerEpoch = int(trainSize / BATCH_SIZE)
validationSteps = int(validSize / BATCH_SIZE)
model_history = model.fit(trainDataSet, epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)
The following link shows a similar case and explains that the Python function "is only executed once to build the function graph and try and except statements will have no effect at that point". The link does show how to iterate through the dataset and catch errors...
dataset = ...
iterator = iter(dataset)
while True:
    try:
        elem = next(iterator)
        ...
    except tf.errors.InvalidArgumentError:
        ...
    except StopIteration:
        break
...but what I'm looking for is a way to catch the error during training itself, while model.fit is running, without giving up the callbacks. Is this possible? For context, the kind of manual fallback I could imagine is sketched below.
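A minimal sketch of that fallback, reusing the names defined above (EPOCHS, trainDataSet, stepsPerEpoch, model) and assuming a bad file surfaces as tf.errors.InvalidArgumentError when the batch containing it is drawn. Note that one bad image would cost the whole batch of 32, and EarlyStopping, ReduceLROnPlateau and ModelCheckpoint would all stop working, which is why I'd rather stay with model.fit:
for epoch in range(EPOCHS):
    iterator = iter(trainDataSet)
    for step in range(stepsPerEpoch):
        try:
            imgs, masks = next(iterator)  # a corrupt file raises here
        except tf.errors.InvalidArgumentError:
            continue  # skip the whole batch containing the bad image
        except StopIteration:
            break  # not expected here, since trainDataSet repeats
        loss, acc = model.train_on_batch(imgs, masks)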