0

I want to use a CNN to solve a deblurring task, and I have training data consisting of a directory of PNG images and a corresponding text file containing the file names.

As the data is too large to load into memory in one step, is there any API or method that would let me read each blurry image as the input and its ground truth as the expected output for training?

I have spent quite a bit of time trying to solve this, but I got confused after reading the online API documentation.

wanger
  • 145
  • 1
  • 6

1 Answer

0

The method is not that confusing. TensorFlow provides the TFRecords file format to make good use of memory.

def create_cord(num_examples=66742):
    """Serialize paired blurred/original images into ``train.tfrecords``.

    For each index, loads the blurred image and its ground-truth original
    (paths built via ``get_file_name``/``cwd``, defined elsewhere in the
    project), resizes both, and writes one ``tf.train.Example`` holding the
    raw bytes of each under the keys ``blur_image_raw`` / ``orig_image_raw``.

    Args:
        num_examples: number of image pairs to serialize (default keeps the
            original hard-coded dataset size of 66742).

    NOTE(review): PIL's ``resize`` takes ``(width, height)``; here it is
    called with ``(IMAGE_HEIGH, IMAGE_WIDTH)``, which matches the
    ``[IMAGE_WIDTH, IMAGE_HEIGH, 3]`` reshape used on the read side — the
    naming is swapped but internally consistent.
    """
    writer = tf.python_io.TFRecordWriter("train.tfrecords")
    try:
        for index in xrange(num_examples):
            blur_file_name = get_file_name(index, True)
            orig_file_name = get_file_name(index, False)
            blur_image_path = cwd + blur_file_name
            orig_image_path = cwd + orig_file_name

            blur_image = Image.open(blur_image_path)
            orig_image = Image.open(orig_image_path)

            blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
            orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))

            blur_image_raw = blur_image.tobytes()
            orig_image_raw = orig_image.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "blur_image_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[blur_image_raw])),
                'orig_image_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[orig_image_raw])),
            }))
            # BUG FIX: the write must happen INSIDE the loop — the original
            # indentation wrote only the final example to the file.
            writer.write(example.SerializeToString())
    finally:
        writer.close()

to read the dataset:

def read_and_decode(filename):
    """Build input tensors that read one (blurred, original) pair per step.

    Sets up a TF1 input queue over the given TFRecords file, parses a single
    serialized ``Example``, and decodes both image features into float32
    tensors normalized to the range [-0.5, 0.5).

    Args:
        filename: path to a TFRecords file written by ``create_cord``.

    Returns:
        A ``(blur_img, orig_img)`` tuple of tensors, each shaped
        ``[IMAGE_WIDTH, IMAGE_HEIGH, 3]``.
    """
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                   features={
                                       'blur_image_raw':    tf.FixedLenFeature([], tf.string),
                                       'orig_image_raw': tf.FixedLenFeature([], tf.string),
                                   })

    blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
    blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    blur_img = tf.cast(blur_img, tf.float32) * (1. / 255) - 0.5

    # BUG FIX: the original decoded 'blur_image_raw' here as well, making the
    # ground truth identical to the blurred input (the model would learn the
    # identity mapping). Decode the 'orig_image_raw' feature instead.
    orig_img = tf.decode_raw(features['orig_image_raw'], tf.uint8)
    orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    orig_img = tf.cast(orig_img, tf.float32) * (1. / 255) - 0.5

    return blur_img, orig_img


if __name__ == '__main__':

    #  create_cord()

    # Read one (blurred, original) pair per step and batch them.
    blur, orig = read_and_decode("train.tfrecords")
    blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig],
                                                batch_size=3, capacity=1000,
                                                min_after_dequeue=100)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Start the input queue threads under a Coordinator so they can be
        # stopped and joined cleanly (the original leaked these threads).
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(3):
                v, l = sess.run([blur_batch, orig_batch])
                print(v.shape, l.shape)
        finally:
            coord.request_stop()
            coord.join(threads)
wanger
  • 145
  • 1
  • 6