I trained a Faster R-CNN for point detection using TensorFlow and its Python API, and I was able to run inference successfully there without any issues. Now, for production requirements, I need to build a standalone application or library that performs the inference in C++.
To achieve that, I built the TensorFlow library from source for the C++ API and successfully ran a few test programs.
Steps followed for inference on my trained model:
- Method #1: freeze the graph to a *.pb file using a modified freeze_graph.py and restore it in C++.
- Method #2: restore directly from the checkpoint using MetaGraphDef and ReadBinaryProto (see the sketch right after this list).
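For reference, Method #2 boils down to roughly the following. This is a minimal sketch, not my exact code; "save/Const:0" and "save/restore_all" are the default tf.train.Saver node names and may differ in your graph.

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
using namespace tensorflow;

Status LoadFromCheckpoint(const std::string& graph_fn, const std::string& checkpoint_fn, std::unique_ptr<Session>& session)
{
    // Read the MetaGraphDef that sits next to the checkpoint (the *.meta file).
    MetaGraphDef metaGraphDef;
    TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), graph_fn, &metaGraphDef));

    // Create a session and load the graph into it.
    session.reset(NewSession(SessionOptions()));
    TF_RETURN_IF_ERROR(session->Create(metaGraphDef.graph_def()));

    // Restore the variables by feeding the checkpoint path to the Saver's restore op.
    Tensor checkpointPath(DT_STRING, TensorShape());
    checkpointPath.scalar<std::string>()() = checkpoint_fn;
    std::vector<Tensor> unused;
    TF_RETURN_IF_ERROR(session->Run({{"save/Const:0", checkpointPath}}, {}, {"save/restore_all"}, &unused));
    return Status::OK();
}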
I tried both methods, and in each case the model appears to load without any error; I also printed the node names to verify them.
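For context, the frozen-graph path (Method #1) is essentially the standard ReadBinaryProto + Session::Create sequence, and that is also where the node names get dumped. A simplified sketch (not my exact LoadSavedModel body), using the class members shown further below:

Status fasterRcnn::LoadSavedModel(string& fileName) // via frozen graph
{
    // Parse the frozen *.pb into a GraphDef and hand it to a new session.
    graphDef.reset(new GraphDef());
    TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), fileName, graphDef.get()));

    fSession.reset(NewSession(SessionOptions()));
    TF_RETURN_IF_ERROR(fSession->Create(*graphDef));

    // Dump the node names to double-check the input/output names used later in Run().
    for (const NodeDef& node : graphDef->node())
        cout << node.name() << " (" << node.op() << ")" << endl;

    return Status::OK();
}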
After loading the graph, I created a handler for the image input. During training the input was fed by a QueueRunner, but since this is inference only, I built the input handler myself with the appropriate modifications, as shown below.
Status fasterRcnn::CreateGraphForImage(bool unstack)
{
    // Build a small graph that reads a JPEG file into a 4D float tensor.
    fileNameVar = Placeholder(iRoot.WithOpName("input"), DT_STRING);
    auto fileReader = ReadFile(iRoot.WithOpName("fileReader"), fileNameVar);
    auto imageReader = DecodeJpeg(iRoot.WithOpName("jpegReader"), fileReader, DecodeJpeg::Channels(imageChannels));
    auto floatCaster = Cast(iRoot.WithOpName("floatCaster"), imageReader, DT_FLOAT);
    auto dimsExpander = ExpandDims(iRoot.WithOpName("dim"), floatCaster, 0);
    //auto resized=ResizeBilinear(iRoot.WithOpName("size"),dimsExpander,Const(iRoot,{imageSide,imageSide}));
    //auto div=Div(iRoot.WithOpName("normalized"),resized,{255.f});
    imageTensorVar = dimsExpander; // resizing/normalization currently disabled
    return iRoot.status();
}
NOTE: iRoot is a private member of the fasterRcnn class: the Scope used by the graph that loads images into tensors.
class fasterRcnn
{
private:
    Scope iRoot;                        // graph for loading images into tensors
    const int imageSide;
    const int imageChannels;            // rgb
    // image-loading vars
    Output fileNameVar;
    Output imageTensorVar;
    std::unique_ptr<Session> fSession;  // file session
    std::unique_ptr<GraphDef> graphDef;
public:
    fasterRcnn(int side, int channels) : iRoot(Scope::NewRootScope()), imageSide(side), imageChannels(channels) {}
    Status CreateGraphForImage(bool unstack);
    Status LoadSavedModel(string& fileName);                                 // via frozen graph
    Status LoadSavedModel(std::string graph_fn, std::string checkpoint_fn);  // via checkpoints
    Status PredictFromFrozen(Tensor& image, int& result);
    Status ReadTensorFromImageFile(string& file_name, Tensor& outTensor);
    Status ReadFileTensors(string& folder_name, vector<Tensor>& file_tensors);
};
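For completeness, ReadTensorFromImageFile only has to run that small image graph. A minimal sketch of how that can be done with a ClientSession over iRoot (simplified, not my exact code; needs tensorflow/cc/client/client_session.h):

Status fasterRcnn::ReadTensorFromImageFile(string& file_name, Tensor& outTensor)
{
    // Feed the file name placeholder and fetch the decoded 4D float image tensor.
    ClientSession session(iRoot);
    std::vector<Tensor> outTensors;
    TF_RETURN_IF_ERROR(session.Run({{fileNameVar, file_name}}, {imageTensorVar}, &outTensors));
    outTensor = outTensors[0]; // shape: [1, height, width, channels]
    return Status::OK();
}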
Now, when the time comes for prediction and I run Session->Run with the particular input and output node names, the program hangs:
Status fasterRcnn::PredictFromFrozen(Tensor& image, int& result)
{
    vector<Tensor> outTensors;
    cout << "Prediction about to start" << endl;
    TF_CHECK_OK(fSession->Run({{"fasterrcnn/truncated_base_network/sub", image}},
                              {"fasterrcnn/rcnn/rcnn_proposal/GatherV2_1"}, {}, &outTensors));
    cout << "Prediction done" << endl;
    auto output_c = outTensors[0].scalar<float>();
    for (int i = 0; i < outTensors.size(); i++)
    {
        cout << "Output dimension of the image " << outTensors[i].DebugString() << "\n";
    }
    cout << "Output dimension of the image " << outTensors[0].DebugString() << "\n";
    return Status::OK();
}
It hangs at the Session->Run call above (the TF_CHECK_OK line) and never returns.
Debugging thoughts so far:
- I thought the QueueRunner might be causing the issue, but according to this answer I have handled that by building the image-handling graph myself.
- I thought I might not have frozen the graph correctly, which is why I also tried loading directly from the checkpoints, but it still hangs.
- The input to the truncated VGG_16 network in the graph expects a 4D float tensor, which is what I am providing; I double-checked that.
Questions:
- Is there any step I am missing for successful inference in C++?
- How can I verify that my input and output node names are correct? I referred to the Python code and they seem right.
- Are there any debugging steps I should follow to provide more info on this?