I have a process that reads the URIs of CSV files located in cloud storage, serializes the data (one file becomes one "example" in TensorFlow speak), and writes them all to the same TFRecord file.
The process is very slow, so I would like to parallelize the writing using Python multiprocessing. I've searched high and low and tried multiple implementations, to no avail. This question is very similar to mine, but it was never really answered.
This is the closest I've come (unfortunately, I can't really provide a reproducible example because of the read from cloud storage):
import pandas as pd
import multiprocessing
import tensorflow as tf

TFR_PATH = "./tfr.tfrecord"
BANDS = ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12"]

def write_tfrecord(tfr_path, df_list, bands):
    with tf.io.TFRecordWriter(tfr_path) as writer:
        for _, grp in df_list:
            band_data = {b: [] for b in bands}
            for i, row in grp.iterrows():
                try:
                    df = pd.read_csv(row['uri'])
                except FileNotFoundError:
                    continue
                df = prepare_df(df, bands)  # defined elsewhere; returns a frame with the band columns
                label = row['FS_crop'].encode()
                for b in bands:
                    band_data[b].append(list(df[b].astype('Int64')))
            # pad to the same length and flatten
            mlen = max(len(j) for j in band_data[bands[0]])
            npx = len(band_data[bands[0]])
            flat_band_data = {k: [] for k in band_data}
            for k, v in band_data.items():  # for each band
                for b in v:
                    flat_band_data[k].extend(b + [0] * (mlen - len(b)))
            example_proto = serialize_example(npx, flat_band_data, label)  # defined elsewhere
            writer.write(example_proto)
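For reference, prepare_df and serialize_example are defined elsewhere in my code; serialize_example just follows the standard tf.train.Example pattern, roughly like this (the "npx" and "label" feature keys here are only illustrative):

# Rough sketch of serialize_example: pack the pixel count, the flattened
# band arrays, and the label into a tf.train.Example and serialize it.
def serialize_example(npx, flat_band_data, label):
    feature = {
        "npx": tf.train.Feature(int64_list=tf.train.Int64List(value=[npx])),
        "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
    }
    for band, values in flat_band_data.items():
        feature[band] = tf.train.Feature(int64_list=tf.train.Int64List(value=values))
    return tf.train.Example(
        features=tf.train.Features(feature=feature)
    ).SerializeToString()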
# List of grouped DF objects; may be 1000's long
gqdf = list(qdf.groupby("field_centroid_str"))

n = 100  # groups of files per process
processes = [
    multiprocessing.Process(target=write_tfrecord, args=(TFR_PATH, gqdf[i:i + n], BANDS))
    for i in range(0, len(gqdf), n)
]
for p in processes:
    p.start()
for p in processes:
    p.join()
    p.close()
The processes finish, but when I go to read a record back, like so:
raw_dataset = tf.data.TFRecordDataset(TFR_PATH)
for raw_record in raw_dataset.take(10):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
I always end up with a corrupted-record error:

DataLossError: corrupted record at 7462 [Op:IteratorGetNext]
Any ideas on the correct approach for doing something like this? I've tried using Pool instead of Process, but tf.io.TFRecordWriter can't be pickled, so that doesn't work.
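For completeness, the Pool attempt looked roughly like this (write_group is a stand-in for the same body as write_tfrecord, just taking the shared writer as an argument instead of creating it):

# One shared writer handed to every worker. starmap has to pickle the
# arguments to ship them to the worker processes, and the TFRecordWriter
# can't be pickled, so this raises a TypeError before any work happens.
def write_group(writer, df_list, bands):
    ...  # same body as write_tfrecord, but writing via the shared writer

with tf.io.TFRecordWriter(TFR_PATH) as writer:
    with multiprocessing.Pool(processes=4) as pool:
        pool.starmap(
            write_group,
            [(writer, gqdf[i:i + n], BANDS) for i in range(0, len(gqdf), n)],
        )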