
I have a process that reads URIs of CSV files located in cloud storage, serializes the data (one file is an "example" in TensorFlow speak), and writes them all to the same TFRecord file.

The process is very slow and I would like to parallelize the writing using Python multiprocessing. I've searched high and low and tried multiple implementations, to no avail. This question is very similar to mine, but it was never really answered.

This is the closest I've come (unfortunately, I can't provide a fully reproducible example because of the reads from cloud storage):

import pandas as pd
import multiprocessing
import tensorflow as tf


TFR_PATH = "./tfr.tfrecord"
BANDS = ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12"]

def write_tfrecord(tfr_path, df_list, bands):
    with tf.io.TFRecordWriter(tfr_path) as writer:
        for _, grp in df_list:
            band_data = {b: [] for b in bands}
            for _, row in grp.iterrows():
                try:
                    df = pd.read_csv(row['uri'])
                except FileNotFoundError:
                    continue
                df = prepare_df(df, bands)
                label = row['FS_crop'].encode()
                for b in bands:
                    band_data[b].append(list(df[b].astype('Int64')))

            # pad each pixel's series to the same length, then flatten per band
            mlen = max(len(j) for j in band_data[bands[0]])
            npx = len(band_data[bands[0]])
            flat_band_data = {k: [] for k in band_data}
            for k, v in band_data.items():  # for each band
                for b in v:
                    flat_band_data[k].extend(b + [0] * (mlen - len(b)))

            example_proto = serialize_example(npx, flat_band_data, label)
            writer.write(example_proto)

# List of grouped DataFrame objects; may be thousands of groups
gqdf = list(qdf.groupby("field_centroid_str"))

n = 100  # number of groups handled by each worker process
processes = [multiprocessing.Process(target=write_tfrecord, args=(TFR_PATH, gqdf[i:i+n], BANDS)) for i in range(0, len(gqdf), n)]

for p in processes:
    p.start()

for p in processes:
    p.join()
    p.close()

The processes finish, but when I go to read a record, like so:

raw_dataset = tf.data.TFRecordDataset(TFR_PATH)
for raw_record in raw_dataset.take(10):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)

I always end up with a corrupted-data error: DataLossError: corrupted record at 7462 [Op:IteratorGetNext]

Any ideas on the correct approach for doing something like this? I've tried using Pool instead of Process, but the tf.io.TFRecordWriter can't be pickled, so it doesn't work.


1 Answer


Ran into a similar use case. The core issue is that the record writer isn't process-safe: when several processes write to the same file concurrently, their records interleave and the file is corrupted, which is exactly the DataLossError you're seeing. There are two bottlenecks, serializing the data and writing the output. My solution was to use multiprocessing (e.g. a pool) to serialize the data in parallel. Each worker passes its serialized bytes to a single consumer process over a queue, and the consumer simply pulls items off the queue and writes them sequentially. If writing then becomes the bottleneck, you can have multiple record writers, each writing to a different file.
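A minimal sketch of that producer/consumer pattern, assuming the per-group work in write_tfrecord (the CSV reads, padding, and serialize_example call) is factored into a serialize_group(group_item, bands) helper that returns the serialized bytes; that helper name and the queue size are illustrative, not from the question:

import multiprocessing

import tensorflow as tf

def producer(queue, df_list, bands):
    # Do the heavy lifting in parallel: everything write_tfrecord did
    # per group, except the actual write. serialize_group is assumed
    # to return the serialized example bytes.
    for group_item in df_list:
        queue.put(serialize_group(group_item, bands))
    queue.put(None)  # sentinel: this producer is finished

def consumer(queue, tfr_path, n_producers):
    # The only process that ever owns a TFRecordWriter, so writes
    # are strictly sequential and the file stays valid.
    finished = 0
    with tf.io.TFRecordWriter(tfr_path) as writer:
        while finished < n_producers:
            item = queue.get()
            if item is None:
                finished += 1
            else:
                writer.write(item)

if __name__ == "__main__":
    queue = multiprocessing.Queue(maxsize=1000)  # bounded, to cap memory use
    n = 100
    chunks = [gqdf[i:i + n] for i in range(0, len(gqdf), n)]

    writer_proc = multiprocessing.Process(
        target=consumer, args=(queue, TFR_PATH, len(chunks)))
    writer_proc.start()

    producers = [multiprocessing.Process(target=producer, args=(queue, chunk, BANDS))
                 for chunk in chunks]
    for p in producers:
        p.start()
    for p in producers:
        p.join()
    writer_proc.join()

If sequential writing itself becomes the limit, the same sketch extends to several consumer processes, each owning its own writer and its own shard file; tf.data.TFRecordDataset accepts a list of filenames, so sharded output reads back just as easily.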
