
I'm training a U-Net model using TensorFlow. If there is a problem with any of the images I pass to the model for training, an exception is raised, sometimes an hour or two into training. Is it possible to catch such exceptions so that the model skips to the next image and training continues? I've tried adding a try/except block to the process_path function shown below, but it has no effect...

def process_path(filePath):
    # catching exceptions here has no effect
    parts = tf.strings.split(filePath, '/')
    fileName = parts[-1]
    parts = tf.strings.split(fileName, '.')
    prefix = tf.convert_to_tensor(maskDir, dtype=tf.string)
    suffix = tf.convert_to_tensor("-mask.png", dtype=tf.string)
    maskFileName = tf.strings.join((parts[-2], suffix))
    maskPath = tf.strings.join((prefix, maskFileName), separator='/')

    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.io.read_file(maskPath)
    oneHot = decodeMask(mask)
    img.set_shape([256, 256, 3])
    oneHot.set_shape([256, 256, 10])
    return img, oneHot
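The comment in process_path that catching exceptions has no effect follows from how Dataset.map works: the Python function is traced into a graph, so Python-level statements (including try/except) run during tracing rather than once per element. The minimal sketch below makes this visible; the counter trace_calls and the function double are hypothetical names used only for illustration.

```python
import tensorflow as tf

# Hypothetical list used only to record how often the Python body runs.
trace_calls = []


def double(x):
    # This Python statement runs while tf.data traces `double` into a graph,
    # not once per element; a try/except placed here behaves the same way.
    trace_calls.append(1)
    return x * 2


dataset = tf.data.Dataset.range(100).map(double)
values = list(dataset.as_numpy_iterator())

# 100 elements are produced, but the Python body ran far fewer times.
print(len(values), "elements,", len(trace_calls), "Python call(s) of double")
```

Errors raised later, while an element is actually being produced (for example a corrupt image reaching decode_img), therefore happen inside the graph and never pass through a Python try/except written in the mapped function.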

trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)
batchSize = 32

allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))

trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.shuffle(1000).repeat()
trainDataSet = trainDataSet.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
trainDataSet = trainDataSet.batch(batchSize)
trainDataSet = trainDataSet.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.shuffle(1000).repeat()
validDataSet = validDataSet.map(process_path)
validDataSet = validDataSet.batch(batchSize)

imageHeight = 256
imageWidth = 256
channels = 3

inputImage = Input((imageHeight, imageWidth, channels), name='img') 
model = baseUnet.get_unet(inputImage, n_filters=16, dropout=0.05, batchnorm=True)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1),
    ModelCheckpoint(outputModel, verbose=1, save_best_only=True, save_weights_only=False)
]

BATCH_SIZE = 32
BUFFER_SIZE = 1000
EPOCHS = 20

stepsPerEpoch = int(trainSize / BATCH_SIZE)
validationSteps = int(validSize / BATCH_SIZE)

model_history = model.fit(trainDataSet, epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)

The following link shows a similar case and explains that the "Python function is only executed once to build the function graph and try and except statements will have no effect at that." Although the link shows how to iterate through the dataset and catch errors...

import tensorflow as tf

dataset = ...
iterator = iter(dataset)

while True:
  try:
    elem = next(iterator)
    ...
  except tf.errors.InvalidArgumentError:
    ...
  except StopIteration:
    break

...I'm looking for a way to catch the error during training, however. Is this possible?
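One technique that does let a Python try/except run per element during training (swapped in here; it is not taken from the question or the answer) is to do the file loading inside tf.py_function, whose Python body executes eagerly for every element. Failed files are flagged and then removed with Dataset.filter. This is a sketch under stated assumptions: PNG inputs resized to 256x256, images only (the mask would need the same treatment), and the names _load_eager, load_with_flag, build_filtered_dataset and make_demo_files are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

IMG_SIZE = 256  # assumed target size, matching the 256x256 shapes in process_path


def _load_eager(path):
    # Runs eagerly once per element inside tf.py_function, so this
    # try/except really executes for every file.
    try:
        img = tf.io.decode_png(tf.io.read_file(path), channels=3)
        img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])  # float32 output
        return img, True
    except (tf.errors.InvalidArgumentError, tf.errors.NotFoundError) as err:
        print(f"Skipping {path.numpy().decode()}: {type(err).__name__}")
        return tf.zeros([IMG_SIZE, IMG_SIZE, 3], tf.float32), False


def load_with_flag(path):
    img, ok = tf.py_function(_load_eager, [path], Tout=[tf.float32, tf.bool])
    # py_function loses static shape information, so restore it here.
    img.set_shape([IMG_SIZE, IMG_SIZE, 3])
    ok.set_shape([])
    return img, ok


def build_filtered_dataset(paths):
    ds = tf.data.Dataset.from_tensor_slices(paths)
    ds = ds.map(load_with_flag)
    ds = ds.filter(lambda img, ok: ok)  # drop elements that failed to load
    return ds.map(lambda img, ok: img)  # strip the flag again


def make_demo_files():
    # Hypothetical test data: two valid PNGs, one corrupt file, one missing path.
    tmp = tempfile.mkdtemp()
    good = os.path.join(tmp, "good.png")
    bad = os.path.join(tmp, "bad.png")
    tf.io.write_file(good, tf.io.encode_png(tf.zeros([8, 8, 3], tf.uint8)))
    with open(bad, "wb") as f:
        f.write(b"not a png")
    return [good, bad, os.path.join(tmp, "missing.png"), good]


if __name__ == "__main__":
    for img in build_filtered_dataset(make_demo_files()):
        print(img.shape)  # only the two valid images come through
```

The trade-off is that tf.py_function holds the Python GIL, which may limit the speedup from num_parallel_calls; in return you can log or count exactly which paths were skipped.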

CSharp

1 Answer


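You might consider using the tf.data.experimental.ignore_errors function, applied with dataset.apply(tf.data.experimental.ignore_errors()). It makes the dataset silently drop any element whose processing raises an error (such as the file that is causing the trouble) instead of aborting iteration and, with it, training.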
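A minimal sketch of that suggestion, assuming PNG inputs: a corrupt file makes decode_png fail and a missing file makes read_file fail, and ignore_errors, placed directly after map, drops just those elements. The names load_png, build_dataset and make_demo_files are hypothetical, chosen for illustration only.

```python
import os
import tempfile

import tensorflow as tf


def load_png(path):
    # read_file raises NotFoundError for a missing file and decode_png raises
    # InvalidArgumentError for corrupt bytes; both surface when the element is produced.
    img = tf.io.decode_png(tf.io.read_file(path), channels=3)
    return tf.image.convert_image_dtype(img, tf.float32)


def build_dataset(paths):
    ds = tf.data.Dataset.from_tensor_slices(paths)
    ds = ds.map(load_png, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Drop every element whose map() call failed instead of ending iteration.
    return ds.apply(tf.data.experimental.ignore_errors())


def make_demo_files():
    # Hypothetical test data: two valid PNGs, one corrupt file, one missing path.
    tmp = tempfile.mkdtemp()
    good = os.path.join(tmp, "good.png")
    bad = os.path.join(tmp, "bad.png")
    tf.io.write_file(good, tf.io.encode_png(tf.zeros([4, 4, 3], tf.uint8)))
    with open(bad, "wb") as f:
        f.write(b"not a png")
    return [good, bad, os.path.join(tmp, "missing.png"), good]


if __name__ == "__main__":
    for img in build_dataset(make_demo_files()):
        print(img.shape)  # only the two valid images come through
```

In the question's pipeline this would correspond to adding .apply(tf.data.experimental.ignore_errors()) right after the map(process_path, ...) call and before batch, so the skip happens per image rather than per batch. Because that pipeline repeats indefinitely and model.fit is given a fixed steps_per_epoch, a few dropped images should not shorten an epoch. The downside is that failures are silent, so bad files are not reported.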

user3415910