I want to fine-tune GoogLeNet for multi-label classification with Caffe. I have already fine-tuned it for single-label classification, but I can't make the transition to multi-label yet.
The main steps I am taking that differ:
Create LMDB for Data & Ground Truth
I am adapting the code here and here to create one LMDB with the data and another with the ground truth.
Replacing SoftmaxWithLoss with SigmoidCrossEntropyLoss
Updating the train_val.prototxt, I replace the SoftmaxWithLoss layers with SigmoidCrossEntropyLoss, and set up the data layers so that both LMDBs are loaded. I set the learning-rate parameters as I did for the single-label classification problem.
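For reference, a rough sketch of what such a train_val.prototxt setup can look like (the LMDB paths, layer names, and batch sizes here are placeholder assumptions, not my actual file; "loss3/classifier" is GoogLeNet's final classifier top):

```protobuf
# TRAIN phase: images in one LMDB, multi-label vectors in a second LMDB.
# Both data layers must use the same batch_size so rows stay aligned.
layer {
  name: "data"
  type: "Data"
  top: "data"
  include { phase: TRAIN }
  data_param { source: "train_images_lmdb" backend: LMDB batch_size: 32 }
}
layer {
  name: "label"
  type: "Data"
  top: "label"
  include { phase: TRAIN }
  data_param { source: "train_labels_lmdb" backend: LMDB batch_size: 32 }
}
# Loss over the final GoogLeNet classifier output.
layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss"
  bottom: "loss3/classifier"
  bottom: "label"
  top: "loss"
}
```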
These steps seem to work: the data flows, and it is possible to run solver.step(1). To verify that the data and labels are loaded correctly, I explicitly calculated the loss with the formula and got the same result as Caffe.
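For anyone who wants to reproduce that check, here is a sketch of the sigmoid cross-entropy computation in the numerically stable form Caffe uses, normalized by batch size (the blob names in the usage comment are assumptions about a typical GoogLeNet setup):

```python
import numpy as np

def sigmoid_cross_entropy_loss(scores, labels):
    """Sigmoid cross-entropy summed over labels and averaged over the
    batch dimension only, matching Caffe's normalization."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    # Numerically stable per-element form:
    #   max(x, 0) - x*y + log(1 + exp(-|x|))
    per_elem = (np.maximum(scores, 0) - scores * labels
                + np.log1p(np.exp(-np.abs(scores))))
    return per_elem.sum() / scores.shape[0]

# Usage sketch (blob names assumed):
#   scores = net.blobs['loss3/classifier'].data
#   labels = net.blobs['label'].data
#   print(sigmoid_cross_entropy_loss(scores, labels), net.blobs['loss'].data)
```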
Problem
The network does not converge. After several hundred iterations, each class's output settles at that class's frequency in the population. That is, if class a is positive in 35% of the observations (0.35 ones, 0.65 zeros), the network converges to a predicted probability of ~0.35 for every observation, regardless of the true label.
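That pattern is exactly what sigmoid cross-entropy produces when the predictions carry no information about the input: the constant probability that minimizes the average loss is the class prior. A quick sanity check of that claim, using the 0.35 / 0.65 split from the example:

```python
import numpy as np

# 35 positives, 65 negatives, as in the class-a example above.
labels = np.array([1.0] * 35 + [0.0] * 65)

def mean_bce(p):
    """Mean binary cross-entropy for a constant predicted probability p."""
    return -(labels * np.log(p) + (1 - labels) * np.log(1 - p)).mean()

# Search constant predictions on a grid: the minimizer is the prior, 0.35.
grid = np.linspace(0.01, 0.99, 99)
best = grid[np.argmin([mean_bce(p) for p in grid])]
```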
Possible error 1
I suspect the problem is that I fail to load the images into Caffe in a way the pretrained GoogLeNet model can learn from. In my previous experience, convert_imageset worked perfectly. Right now I am using Shelhamer's code to save the images into the LMDB:
import os
import numpy as np
from PIL import Image
import caffe

im = np.array(Image.open(os.path.join(data_dir, in_)))  # data_dir, in_ come from the surrounding loop
im = im[:, :, ::-1]           # RGB -> BGR, the channel order Caffe's pretrained models expect
im = im.transpose((2, 0, 1))  # HxWxC -> CxHxW
im_dat = caffe.io.array_to_datum(im)
in_txn.put('{:0>10d}'.format(in_idx), im_dat.SerializeToString())
I subtract the image mean in the data layer when loading the image. Does that seem right? Is there another way to do it?
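Mean subtraction in the data layer is typically done with transform_param; a sketch with the usual ILSVRC per-channel means (in BGR order, matching the channel flip above; the LMDB path is a placeholder, and your own dataset's means may differ):

```protobuf
layer {
  name: "data"
  type: "Data"
  top: "data"
  transform_param {
    mean_value: 104  # B
    mean_value: 117  # G
    mean_value: 123  # R
  }
  data_param { source: "train_images_lmdb" backend: LMDB batch_size: 32 }
}
```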
Possible error 2
It might also be that the train_val.prototxt is defined incorrectly. Is there anything else that needs to be done besides switching SoftmaxWithLoss -> SigmoidCrossEntropyLoss?
Any help will be greatly appreciated! Thanks!