
I am trying to create a tf.data.Dataset from a generator I wrote and write it out as several sharded .tfrecord files, following this great answer: Split .tfrecords file into many .tfrecords files

Generator Code

def get_examples_generator(num_variants, vcf_reader):
    def generator():
        counter = 0
        for vcf_read in vcf_reader:
            is_vcf_ok = ... # checking whether this "vcf" example is ok

            if is_vcf_ok and counter < num_variants:

                counter += 1

                # features extraction ...

                # we create an example
                example = make_example(img=img, label=label)  # returns a serialized tf.train.Example (bytes)

                yield example
    return generator
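tf.data.Dataset.from_generator calls generator() again every time the dataset is iterated, and dataset.shard(num_shards, index=i) iterates the whole upstream dataset to pick out its elements, so the generator above runs once per shard. The plain-Python sketch below (no TensorFlow needed) shows why that matters if vcf_reader is a one-shot iterator, as many file readers are: only the first pass sees any records. Here reader, is_ok and the placeholder byte strings are hypothetical stand-ins for vcf_reader, the elided check and make_example.

```python
def get_examples_generator(num_variants, reader, is_ok=lambda record: True):
    """Return a zero-argument callable, as tf.data.Dataset.from_generator expects."""
    def generator():
        counter = 0
        for record in reader:
            if not is_ok(record):
                continue  # skip records that fail the (hypothetical) quality check
            # placeholder for make_example(...): any bytes object works here
            yield f"example-{record}".encode()
            counter += 1
            if counter >= num_variants:
                return  # stop reading once enough examples were produced
    return generator


# A one-shot iterator behaves like a file reader that cannot be rewound.
one_shot = get_examples_generator(100, iter(range(5)))
print(len(list(one_shot())), len(list(one_shot())))  # 5 0

# A re-iterable source (a list, or a reader reopened inside generator()) works every time.
reusable = get_examples_generator(100, list(range(5)))
print(len(list(reusable())), len(list(reusable())))  # 5 5
```

Unlike the question's version, this sketch stops reading as soon as num_variants examples have been yielded instead of scanning the rest of the input.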

TFRecordWriter Usage Code

from pathlib import Path

import numpy as np
import tensorflow as tf

def write_sharded_tfrecords(filename, path, vcf_reader,
                            num_variants,
                            shard_len):
    assert Path(path).exists(), "path does not exist"

    generator = get_examples_generator(num_variants=num_variants,
                                       vcf_reader=vcf_reader)

    dataset = tf.data.Dataset.from_generator(generator,
                                             output_types=tf.string,
                                             output_shapes=())

    num_shards = int(np.ceil(num_variants/shard_len))
    formatter = lambda batch_idx: f'{path}/{filename}-{batch_idx:05d}-of-' \
                                  f'{num_shards:05d}.tfrecord'
    # inspired by https://stackoverflow.com/questions/54519309/split-tfrecords-file-into-many-tfrecords-files
    for i in range(num_shards):
        shard_path = formatter(i)
        writer = tf.data.experimental.TFRecordWriter(shard_path)
        shard = dataset.shard(num_shards, index=i)
        writer.write(shard)

This should be a straightforward use of the TFRecord writer. However, it does not write any files at all. Does anyone know why this doesn't work?
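For reference, dataset.shard(num_shards, index=i) distributes elements round-robin: shard i keeps elements i, i + num_shards, i + 2*num_shards, and so on, rather than a contiguous block of shard_len elements. A small plain-Python illustration of that rule together with the num_shards and filename arithmetic from the code above (the sizes and the "data" prefix are made up):

```python
import math


def shard_indices(num_elements, num_shards, index):
    """Positions kept by Dataset.shard(num_shards, index): j with j % num_shards == index."""
    return [j for j in range(num_elements) if j % num_shards == index]


num_variants, shard_len = 10, 4
num_shards = math.ceil(num_variants / shard_len)  # 3, like int(np.ceil(num_variants / shard_len))
for i in range(num_shards):
    name = f"data-{i:05d}-of-{num_shards:05d}.tfrecord"
    print(name, shard_indices(num_variants, num_shards, i))
# data-00000-of-00003.tfrecord [0, 3, 6, 9]
# data-00001-of-00003.tfrecord [1, 4, 7]
# data-00002-of-00003.tfrecord [2, 5, 8]
```

Each shard therefore holds roughly num_variants / num_shards elements, and every shard requires a full pass over the input.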

yonatansc97

1 Answer


In my own functions, I create the writer with tf.io.TFRecordWriter. Try switching to that writer and see if it works:

writer = tf.io.TFRecordWriter(shard_path)  # instead of tf.data.experimental.TFRecordWriter(shard_path)
...
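A minimal end-to-end sketch of this suggestion, assuming TensorFlow 2.x running eagerly: each shard is written with tf.io.TFRecordWriter, one record at a time, since its write() takes a single serialized record rather than a dataset. The helper name, the toy records and the temporary output directory are illustrative, not taken from the question. One possible reason the original loop writes nothing: according to the TensorFlow docs, tf.data.experimental.TFRecordWriter.write(dataset) only performs the write immediately in eager mode; in graph mode (TF 1.x style) it returns an op that writes nothing until it is run in a session.

```python
import math
import os
import tempfile

import tensorflow as tf  # assumes TF 2.x with eager execution (the default)


def write_sharded_tfrecords(dataset, path, filename, num_shards):
    """Write a dataset of scalar tf.string records into num_shards TFRecord files."""
    os.makedirs(path, exist_ok=True)
    paths = []
    for i in range(num_shards):
        shard_path = os.path.join(path, f"{filename}-{i:05d}-of-{num_shards:05d}.tfrecord")
        # tf.io.TFRecordWriter writes each record when write() is called;
        # the context manager closes (and flushes) the file at the end.
        with tf.io.TFRecordWriter(shard_path) as writer:
            for record in dataset.shard(num_shards, index=i):  # eager iteration
                writer.write(record.numpy())  # scalar tf.string tensor -> bytes
        paths.append(shard_path)
    return paths


# Toy byte strings standing in for serialized tf.train.Example protos.
records = [f"record-{k}".encode() for k in range(10)]
dataset = tf.data.Dataset.from_tensor_slices(records)
out_dir = tempfile.mkdtemp()
shard_paths = write_sharded_tfrecords(dataset, out_dir, "toy", num_shards=math.ceil(10 / 4))
print(sorted(os.listdir(out_dir)))
print(sum(1 for _ in tf.data.TFRecordDataset(shard_paths)))  # 10
```

Reading the shards back with tf.data.TFRecordDataset, as in the last line, is a quick way to confirm that the files exist and contain every record.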

As a further reference, this answer helped me:

https://stackoverflow.com/a/60283571

emil