
I am trying to make a Dataset that provides batches of TFRecords where, within one batch, there are 2 random records from one class and the rest from other random classes.

OR

A Dataset of batches where there are 2 random records from each class that fits into the batch.

I tried to do this with tf.data.Dataset.from_generator and with tf.data.experimental.choose_from_datasets, but with no success. Do you have an idea of how to do this?

EDIT: Today I think I implemented the second variant. Here is the code I was testing it on.

import tensorflow as tf

def input_fn():
  # three toy "class" datasets with disjoint value ranges
  partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
  partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
  partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
  l = [partial1, partial2, partial3]

  def gen(x):
    # emit dataset index x twice in a row: 0, 0, 1, 1, 2, 2, ...
    return tf.data.Dataset.range(x, x + 1).repeat(2)

  # the choice sequence decides which dataset the next element comes from
  dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)

  choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
  return choice
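
For reference, a minimal graph-mode loop along these lines can be used to evaluate it (a sketch; the original post does not show the evaluation code):

it = input_fn().make_one_shot_iterator()
next_batch = it.get_next()
with tf.Session() as sess:
    for _ in range(15):
        print(sess.run(next_batch))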

which, when evaluated, returns batches of 2 consecutive records from each of two classes:

[ 0  2 21 22]
[60 61  1  4]
[20 23 62 63]
[ 3  5 24 25]
[64 66  6  7]
[26 27 65 68]
[ 8  0 28 29]
[67 69  9  2]
[20 22 60 62]
[ 3  1 23 24]
[63 61  4  6]
[25 26 65 64]
[ 7  5 27 28]
[67 66  9  8]
[21 20 69 68]
Mous
  • Not sure how to do so with a Dataset.from_generator, but [this post may be helpful to you](https://stackoverflow.com/questions/38260113/implementing-contrastive-loss-and-triplet-loss-in-tensorflow/38270293#38270293) – sashimi Apr 02 '19 at 23:37
  • Possible duplicate of [Implementing contrastive loss and triplet loss in Tensorflow](https://stackoverflow.com/questions/38260113/implementing-contrastive-loss-and-triplet-loss-in-tensorflow) – sashimi Apr 02 '19 at 23:38
  • I have no issue with the loss calculation; I am using the one from tf.contrib. But how do I make a Dataset that would have some triplets? I updated the question with what I think might work. – Mous Apr 03 '19 at 13:33

2 Answers


In TF 2.0, you can now use dataset.interleave to read each class's TFRecords and dataset.batch to build the triplet pairs:

import tensorflow as tf
import matplotlib.pyplot as plt

# FcaeRecHelper comes from the author's own project (not a TF API)
h = FcaeRecHelper('data/ms1m_img_ann.npy', [112, 112], 128, use_softmax=False)
len(h.train_list)
img_shape = list(h.in_hw) + [3]

is_augment = True
is_normlize = False

def parser(stream: bytes):
    # parse a single serialized TFRecord example into (image, label)
    examples: dict = tf.io.parse_single_example(
        stream,
        {'img': tf.io.FixedLenFeature([], tf.string),
         'label': tf.io.FixedLenFeature([], tf.int64)})
    return tf.image.decode_jpeg(examples['img'], 3), examples['label']

def pair_parser(raw_imgs, labels):
    # apply the same augmentation to all images in the group
    if is_augment:
        raw_imgs, _ = h.augment_img(raw_imgs, None)
    # normalize the images, or just cast to float32
    if is_normlize:
        imgs: tf.Tensor = h.normlize_img(raw_imgs)
    else:
        imgs = tf.cast(raw_imgs, tf.float32)

    imgs.set_shape([4] + img_shape)
    labels.set_shape([4, ])
    # Note y_true shape will be [batch,3]
    return (imgs[0], imgs[1], imgs[2]), (labels[:3])

batch_size = 1
# h.train_list : ['a.tfrecords','b.tfrecords','c.tfrecords',...]
ds = (tf.data.Dataset.from_tensor_slices(h.train_list)
        .interleave(lambda x: tf.data.TFRecordDataset(x)
                    .shuffle(100)
                    .repeat(), cycle_length=-1,  # -1 means AUTOTUNE
                    # block_length=2 is important: it takes 2 consecutive
                    # records from each class before moving to the next one
                    block_length=2,
                    num_parallel_calls=-1)
        .map(parser, -1)
        .batch(4, True)  # 4 = 2 classes x 2 records; True drops the remainder
        .map(pair_parser, -1)
        .batch(batch_size, True))

# show the (anchor, positive, negative) images of the first triplet in each batch
iters = iter(ds)
for i in range(20):
    imgs, labels = next(iters)
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(imgs[0].numpy().astype('uint8')[0])
    axs[1].imshow(imgs[1].numpy().astype('uint8')[0])
    axs[2].imshow(imgs[2].numpy().astype('uint8')[0])
    plt.show()
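
As a minimal illustration of the block_length=2 trick, the same interleave pattern can be reproduced with toy integer datasets in place of the TFRecord files (a sketch; the names and values here are illustrative only):

import tensorflow as tf

# three toy "classes", stand-ins for the per-class .tfrecords files
starts = tf.data.Dataset.from_tensor_slices(tf.constant([0, 20, 60], tf.int64))

ds = (starts
      .interleave(lambda s: tf.data.Dataset.range(10)
                  .map(lambda x: x + s)  # shift into this class's value range
                  .shuffle(10)
                  .repeat(),
                  cycle_length=3,
                  block_length=2)        # 2 consecutive records per class
      .batch(4, drop_remainder=True))

for batch in ds.take(3):
    print(batch.numpy())  # e.g. [ 3  7 25 21]: 2 records each from 2 classes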
郑启航

OK, I figured it out. The Dataset is generated successfully and the randomness of the data seems decent. It is not an ideal solution for triplet loss, though, since the triplets are random rather than semi-hard.

def input_fn(self, params):
    batch_size = params['batch_size']

    assert self.data_dir, 'data_dir is required'
    shuffle = self.is_training

    dirs = list(map(lambda x: os.path.join(x, 'train-*' if self.is_training else 'validation-*'),
                    self.dirs))

    def prefetch_dataset(filename): 
      dataset = tf.data.TFRecordDataset( 
          filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
      return dataset

    datasets = []
    for glob in dirs:
      dataset = tf.data.Dataset.list_files(glob)
      dataset = dataset.apply(
        tf.contrib.data.parallel_interleave(
            prefetch_dataset,
            cycle_length=FLAGS.num_files_infeed,
            sloppy=True))  # sloppy=True trades output order for speed
      dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
      datasets.append(dataset)

    def gen(x):
      # emit dataset index x twice in a row: 0, 0, 1, 1, ...
      return tf.data.Dataset.range(x, x + 1).repeat(2)

    choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)

    # parse each element of the dataset in parallel
    dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map(
        self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)

    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)

    return dataset
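
The batches from this input_fn can then be fed to the tf.contrib loss mentioned in the comments, which mines semi-hard triplets within each batch (so the batches themselves do not need to be semi-hard). A minimal sketch, assuming dataset_parser yields (image, label) pairs and an embedding_model defined elsewhere:

images, labels = dataset.make_one_shot_iterator().get_next()
# embeddings must be l2-normalized for the semi-hard triplet loss
embeddings = tf.nn.l2_normalize(embedding_model(images), axis=1)
loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
    labels=labels, embeddings=embeddings, margin=1.0)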
Mous