I presume you mean image_dataset_from_directory, since you are loading images and not text data. Either way, these helper functions cannot produce batches with multiple inputs. The documentation defines the return value as:
A tf.data.Dataset object.
- If label_mode is None, it yields float32 tensors of shape (batch_size, image_size[0], image_size[1], num_channels), encoding images (see below for rules regarding num_channels).
- Otherwise, it yields a tuple (images, labels), where images has shape (batch_size, image_size[0], image_size[1], num_channels), and labels follows the format described below.
Instead, you will need to write your own custom generator function that yields multiple inputs loaded from your data directory, and then call fit with that generator, passing a separate generator that produces validation data through the validation_data keyword argument. (Note: in some older versions of Keras you may need fit_generator instead of fit.)
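As a rough sketch of the structure such a generator needs to yield, the following uses random NumPy arrays in place of real images. The name toy_pair_generator, the image shape, and the binary labels are all made up for illustration. The only point is that each yielded item is an (inputs, labels) pair whose inputs hold one array per model input.

```python
import numpy as np


def toy_pair_generator(batch_size=4, image_shape=(32, 32, 3), seed=0):
    """Yield ((input_1, input_2), labels) batches forever.

    Hypothetical stand-in for a real loader: the arrays are random noise
    and the labels are random 0/1 values.
    """
    rng = np.random.default_rng(seed)
    while True:
        in1 = rng.random((batch_size, *image_shape), dtype=np.float32)
        in2 = rng.random((batch_size, *image_shape), dtype=np.float32)
        labels = rng.integers(0, 2, size=(batch_size,)).astype(np.float32)
        yield (in1, in2), labels


# Pull one batch to inspect its structure.
(first_in1, first_in2), first_labels = next(toy_pair_generator())
```

Each input array carries the batch dimension first, so a two-input model would receive two arrays of shape (batch_size, height, width, channels) per step.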
Here's an example module of helper functions that read images from two directories and present them as multi-image inputs during training.
import os

import numpy as np
import skimage.color
import skimage.io
import skimage.util

# INPUT1_PATH_PREFIX and INPUT2_PATH_PREFIX are module-level constants
# naming the two image directories.


def _generate_batch(training):
    in1s, in2s, labels = [], [], []
    batch_tuples = _sample_batch_of_paths(training)
    for input1_path, input2_path in batch_tuples:
        # skip any exception so that image GPU batch loading isn't
        # disrupted and any faulty image is just skipped.
        try:
            in1_tmp = _load_image(
                os.path.join(INPUT1_PATH_PREFIX, input1_path),
            )
            in2_tmp = _load_image(
                os.path.join(INPUT2_PATH_PREFIX, input2_path),
            )
        except Exception as exc:
            print("Unhandled exception during image batch load. Skipping...")
            print(str(exc))
            continue
        # if no exception, both images loaded so both are added to batch.
        in1s.append(in1_tmp)
        in2s.append(in2_tmp)
        # Whatever your custom logic is to determine the label for the pair.
        labels.append(
            _label_calculation_helper(input1_path, input2_path)
        )
    # Stack each list of images into one array with a leading batch axis.
    in1s, in2s = map(skimage.io.concatenate_images, [in1s, in2s])
    # could also add a singleton channel dimension for grayscale images.
    # in1s = in1s[:, :, :, None]
    # Return the inputs as a tuple rather than a list: newer Keras versions
    # may try to convert a list of same-shaped arrays into a single tensor.
    return (in1s, in2s), np.asarray(labels)


def _make_generator(training=True):
    # Loop forever; Keras stops each epoch after steps_per_epoch batches.
    while True:
        yield _generate_batch(training)


def make_generators():
    return _make_generator(training=True), _make_generator(training=False)
The helper _load_image could be something like this:
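def _load_image(path, is_gray=False):
    tmp = skimage.io.imread(path)
    # Drop any alpha channel first so the conversions below only ever
    # see a 2-D grayscale image or a 3-channel RGB image.
    if tmp.ndim == 3 and tmp.shape[-1] == 4:
        tmp = skimage.color.rgba2rgb(tmp)
    if is_gray:
        # Only convert if the file is not already grayscale.
        if tmp.ndim == 3:
            tmp = skimage.color.rgb2gray(tmp)
    elif tmp.ndim == 2:
        # Only expand to RGB if the file is grayscale.
        tmp = skimage.color.gray2rgb(tmp)
    tmp = skimage.util.img_as_float(tmp)
    # Do other stuff here - resizing, clipping, etc.
    return tmp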
The helper function that samples a batch from the paths listed on disk could look like this:
import os
import random
from functools import lru_cache

from sklearn.model_selection import train_test_split

# TEST_SIZE, RANDOM_SEED and BATCH_SIZE are module-level constants.


@lru_cache(1)
def _load_and_split_input_paths():
    # Cached so the directories are listed and split only once.
    training_in1s, testing_in1s = train_test_split(
        os.listdir(INPUT1_PATH_PREFIX),
        test_size=TEST_SIZE,
        random_state=RANDOM_SEED
    )
    training_in2s, testing_in2s = train_test_split(
        os.listdir(INPUT2_PATH_PREFIX),
        test_size=TEST_SIZE,
        random_state=RANDOM_SEED
    )
    return training_in1s, testing_in1s, training_in2s, testing_in2s


def _sample_batch_of_paths(training):
    training_in1s, testing_in1s, training_in2s, testing_in2s = _load_and_split_input_paths()
    if training:
        return list(zip(
            random.sample(training_in1s, BATCH_SIZE),
            random.sample(training_in2s, BATCH_SIZE)
        ))
    else:
        return list(zip(
            random.sample(testing_in1s, BATCH_SIZE),
            random.sample(testing_in2s, BATCH_SIZE)
        ))
This would randomly sample images from some "input 1" directory and pair them with random samples from an "input 2" directory. Obviously in your use case you'll want to change this so that the data are pulled deterministically according to the file structure that defines their pairings and labelings.
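One way to make the pairing deterministic is sketched below, under the purely hypothetical assumption that paired images share the same file name in both directories. The names pair_paths_by_name and iter_batches are illustrative and are not part of the module above; your own directory layout may call for a different matching rule.

```python
def pair_paths_by_name(in1_names, in2_names):
    """Return sorted (input1_name, input2_name) pairs for names present in both lists."""
    common = sorted(set(in1_names) & set(in2_names))
    return [(name, name) for name in common]


def iter_batches(pairs, batch_size):
    """Yield consecutive batches of pairs in order; the last batch may be smaller."""
    for start in range(0, len(pairs), batch_size):
        yield pairs[start:start + batch_size]


# Toy file listings standing in for os.listdir() on each directory.
pairs = pair_paths_by_name(
    ["a.png", "b.png", "c.png"],
    ["b.png", "c.png", "d.png"],
)
batches = list(iter_batches(pairs, batch_size=1))
```

In the module above, the two name lists would come from os.listdir on each input directory, and batches built this way would take the place of the random.sample calls in _sample_batch_of_paths.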
Finally, once you want to use this, you can call training code such as:
training_generator, testing_generator = make_generators()
try:
    some_compiled_model.fit(
        training_generator,
        epochs=EPOCHS,
        validation_data=testing_generator,
        callbacks=[...],
        verbose=VERBOSE,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_steps=VALIDATION_STEPS,
    )
except KeyboardInterrupt:
    pass
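Because the generators above loop forever, Keras cannot tell where an epoch ends, which is why steps_per_epoch and validation_steps are passed explicitly. A common choice, sketched here with made-up dataset sizes, is the number of batches needed to cover each split once. The helper name steps_for is hypothetical.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to see num_samples items once."""
    return math.ceil(num_samples / batch_size)


# Hypothetical split sizes, for illustration only.
STEPS_PER_EPOCH = steps_for(num_samples=1000, batch_size=32)
VALIDATION_STEPS = steps_for(num_samples=250, batch_size=32)
```

Since the module above samples paths at random, this many steps only approximates one pass over the data. With deterministic batching it would correspond to exactly one pass.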