0

I followed this to restore my model and run inference on a single image. My neural network is fully convolutional, so it can handle images of variable width and height. I use the following code during training. Note that the images I used for training are 64x64 (width x height) and are fetched by a queue runner.

# Placeholder for a directly fed image batch: batch size 1 and 3 channels are fixed,
# while height and width are left unspecified (None).
feed_input = tf.placeholder(tf.float32, (1, None, None, 3), name=FEED_INPUT_PLACEHOLDER_NAME)
# Dummy 1x1 value supplied for the placeholder during training. np.ndarray only
# allocates the buffer, so its contents are arbitrary (and dtype=float is float64).
temp_feed_input = np.ndarray(shape=(1, 1, 1, 3), dtype=float)
...
# Training step: is_using_feed_input is fed False and feed_input gets the dummy value.
# NOTE(review): presumably False makes the graph read from the input queue rather
# than from feed_input -- confirm against the elided graph-building code.
sess.run([train_op], feed_dict={is_using_feed_input:False, feed_input:temp_feed_input})

I want to run inference on an image of a different size, say 147x256. I use the following code to run inference on a single image, but it raises an exception saying that the fed input of shape [1, 147, 256, 3] is not compatible with [1, ?, ?, 3].

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from fuzhi_constants import *

def inference(model_dir, ldr_image_path):
    """Restore the latest checkpoint from ``model_dir`` and run it on one image.

    The image at ``ldr_image_path`` is read with OpenCV, converted to float32
    and given a leading batch axis. It is then fed to the restored graph
    through the placeholders named by ``FEED_INPUT_PLACEHOLDER_NAME`` and
    ``IS_USING_FEED_INPUT_PLACEHOLDER_NAME`` (the latter is fed ``True``).
    The first element of the evaluated ``conv_out/sub:0`` tensor is written
    to ``regression.exr`` in the same directory as the input image, and the
    time spent in ``sess.run`` is printed.

    Args:
        model_dir: Directory handed to ``tf.train.get_checkpoint_state`` to
            locate the checkpoint and its ``.meta`` graph file.
        ldr_image_path: Path of the input image.

    Raises:
        IOError: If the image cannot be read, no checkpoint is found in
            ``model_dir``, or the result cannot be written.
        ValueError: If the decoded image is not a 3-channel (H, W, 3) array.
    """
    import os
    import cv2
    import time
    import numpy as np
    import tensorflow as tf

    regression_path = os.path.join(os.path.dirname(ldr_image_path), 'regression.exr')

    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly before touching the result.
    ldr_img = cv2.imread(ldr_image_path, cv2.IMREAD_UNCHANGED)
    if ldr_img is None:
        raise IOError('could not read image: %s' % ldr_image_path)
    # IMREAD_UNCHANGED keeps grayscale images 2-D and keeps alpha channels,
    # while the feed placeholder is declared as (1, None, None, 3) in the
    # training code; reject anything else up front with a clear message.
    if ldr_img.ndim != 3 or ldr_img.shape[2] != 3:
        raise ValueError('expected an image of shape (H, W, 3), got %s'
                         % (ldr_img.shape,))
    # Add the batch axis: (H, W, 3) -> (1, H, W, 3).
    ldr_img = np.expand_dims(ldr_img.astype(np.float32), axis=0)

    # get_checkpoint_state returns None when the directory has no checkpoint.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('no checkpoint found in: %s' % model_dir)

    with tf.Graph().as_default(), tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
        saver.restore(sess, ckpt.model_checkpoint_path)

        # Start the queue runners stored in the imported graph.
        # see e.g. https://www.tensorflow.org/programmers_guide/reading_data
        coord = tf.train.Coordinator()
        threads = []
        try:
            for qr in sess.graph.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                                 start=True))
            start_time = time.time()
            regression = sess.run(
                'conv_out/sub:0',
                feed_dict={
                    IS_USING_FEED_INPUT_PLACEHOLDER_NAME + ':0': True,
                    FEED_INPUT_PLACEHOLDER_NAME + ':0': ldr_img,
                })
            elapsed_time = time.time() - start_time
            print('elapsed time: %s seconds' % (elapsed_time))
        finally:
            # Stop and join the queue-runner threads while the session is
            # still open so they do not outlive it.
            coord.request_stop()
            coord.join(threads)

    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(regression_path, regression[0]):
        raise IOError('could not write result: %s' % regression_path)

if __name__ == '__main__':
    # Script entry point: run inference once with the hard-coded checkpoint
    # directory and input image.
    inference(
        model_dir='G:\\neural_network\\optimizer=Adadelta',
        ldr_image_path='G:\\input_ldr.png',
    )
fuzhi
  • 41
  • 4

0 Answers0