I followed this to restore my model and run inference on a single image. My neural network is fully convolutional, so it can handle images of variable width and height. I use the following code during training; note that the training images are 64×64 and are fetched by a queue runner.
# Training-time excerpt. Height and width of the feed placeholder are left as
# None so the fully-convolutional network can accept images of any size.
feed_input = tf.placeholder(tf.float32, (1, None, None, 3), name=FEED_INPUT_PLACEHOLDER_NAME)
# Dummy value fed to feed_input while training. Presumably the graph reads the
# queue-runner batch when is_using_feed_input is False (confirm in the elided
# model code). np.zeros is used because np.ndarray(...) returns uninitialized
# memory, and float32 matches the placeholder's dtype.
temp_feed_input = np.zeros((1, 1, 1, 3), dtype=np.float32)
# ... (model construction and train_op definition elided) ...
sess.run([train_op], feed_dict={is_using_feed_input: False, feed_input: temp_feed_input})
I want to run inference on an image of a different size, say 147x256. I use the following code to run inference on a single image, but it raises an exception saying the feed input [1, 147, 256, 3] is not compatible with [1, ?, ?, 3].
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from fuzhi_constants import *
def inference(model_dir, ldr_image_path, output_tensor_name='conv_out/sub:0',
              regression_path=None):
    """Run the latest checkpoint in ``model_dir`` on one image and save the result.

    The graph is re-created from the checkpoint's ``.meta`` file, so the model
    code does not need to be imported. The image is fed through the placeholder
    named ``FEED_INPUT_PLACEHOLDER_NAME`` with the placeholder named
    ``IS_USING_FEED_INPUT_PLACEHOLDER_NAME`` set to True. The tensor
    ``output_tensor_name`` is then evaluated, and its first batch element is
    written with ``cv2.imwrite``.

    Args:
        model_dir: Directory passed to ``tf.train.get_checkpoint_state``; the
            checkpoint it reports as latest is restored.
        ldr_image_path: Path of the input image. It is read with
            ``cv2.IMREAD_UNCHANGED`` and converted to float32.
        output_tensor_name: Name of the tensor to evaluate. Defaults to
            ``'conv_out/sub:0'``.
        regression_path: Output file path. Defaults to ``regression.exr`` in
            the same directory as ``ldr_image_path``.

    Returns:
        The array returned by ``sess.run`` for ``output_tensor_name``.

    Raises:
        IOError: If the image cannot be read, no checkpoint is found in
            ``model_dir``, or the result cannot be written.
        ValueError: If the image cannot be brought to 3 channels.
    """
    import os
    import cv2
    import time
    import numpy as np
    import tensorflow as tf

    if regression_path is None:
        regression_path = os.path.join(os.path.dirname(ldr_image_path),
                                       'regression.exr')

    # cv2.imread returns None (it does not raise) when the file is missing or
    # cannot be decoded.
    ldr_img = cv2.imread(ldr_image_path, cv2.IMREAD_UNCHANGED)
    if ldr_img is None:
        raise IOError('Could not read image: %s' % ldr_image_path)
    ldr_img = ldr_img.astype(np.float32)
    # IMREAD_UNCHANGED keeps the file's own channel layout: grayscale files
    # come back 2-D and images with alpha come back with 4 channels. The feed
    # placeholder expects exactly 3 channels.
    if ldr_img.ndim == 2:
        ldr_img = cv2.cvtColor(ldr_img, cv2.COLOR_GRAY2BGR)
    elif ldr_img.ndim == 3 and ldr_img.shape[2] == 4:
        ldr_img = cv2.cvtColor(ldr_img, cv2.COLOR_BGRA2BGR)
    if ldr_img.ndim != 3 or ldr_img.shape[2] != 3:
        raise ValueError('Expected a 3-channel image, got shape %s'
                         % (ldr_img.shape,))
    # NOTE(review): pixels are fed unscaled, in OpenCV's BGR channel order;
    # confirm this matches the preprocessing used by the training pipeline.
    ldr_img = np.expand_dims(ldr_img, axis=0)  # add the batch dimension

    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('No checkpoint found in: %s' % model_dir)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
            saver.restore(sess, ckpt.model_checkpoint_path)

            # Start the queue runners stored in the restored graph, as during
            # training.
            # See e.g. https://www.tensorflow.org/programmers_guide/reading_data
            coord = tf.train.Coordinator()
            threads = []
            try:
                for qr in sess.graph.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(qr.create_threads(sess, coord=coord,
                                                     daemon=True, start=True))

                feed_dict = {
                    IS_USING_FEED_INPUT_PLACEHOLDER_NAME + ':0': True,
                    FEED_INPUT_PLACEHOLDER_NAME + ':0': ldr_img,
                }
                start_time = time.time()
                regression = sess.run(output_tensor_name, feed_dict=feed_dict)
                elapsed_time = time.time() - start_time
                print('elapsed time: %s seconds' % (elapsed_time))
            finally:
                # Stop and join the queue-runner threads while the session is
                # still open, so they are not left running when it closes.
                coord.request_stop()
                coord.join(threads)

    # cv2.imwrite reports failure through its return value.
    # NOTE(review): recent OpenCV builds may need OPENCV_IO_ENABLE_OPENEXR set
    # to write .exr files.
    if not cv2.imwrite(regression_path, regression[0]):
        raise IOError('Could not write result to: %s' % regression_path)
    return regression
if __name__ == '__main__':
    # Script entry point: process one image with the checkpoint stored under
    # a hard-coded Windows model directory.
    inference(model_dir='G:\\neural_network\\optimizer=Adadelta',
              ldr_image_path='G:\\input_ldr.png')