I am using Matterport Mask RCNN as my model and I'm trying to build my dataset for training. After much deliberation over the problem below, I think what I'm actually asking is: how do I add more than one class (plus BG)?
I get the following `AssertionError`:
```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-21-c20768952b65> in <module>()
15
16 # display image with masks and bounding boxes
---> 17 display_instances(image, bbox, masks, class_ids/4, train_set.class_names)
/usr/local/lib/python3.6/dist-packages/mask_rcnn-2.1-py3.6.egg/mrcnn/visualize.py in display_instances(image, boxes, masks, class_ids, class_names, scores, title, figsize, ax, show_mask, show_bbox, colors, captions)
103 print("\n*** No instances to display *** \n")
104 else:
--> 105 assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]
106
107 # If no axis is passed, create one and automatically call show()
AssertionError:
```
The problem appears to come from `masks.shape[-1] == class_ids.shape[0]` evaluating to `False`, which should not be the case.
I have now traced it to `class_ids.shape[0]` being 4 times `masks.shape[-1]`, and I think this may have something to do with having 4 classes in the data. Unfortunately, I haven't worked out how to solve this problem.
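To isolate the mismatch, here is a minimal NumPy-only sketch of the same counting logic as the loop in `load_mask`. It does not call Mask R-CNN at all; the three boxes, the 128×128 image size and the `class_names` list (with `BG` at index 0) are made-up assumptions for illustration.

```python
import numpy as np

# Assumed class list: BG plus the four classes named in the question.
class_names = ['BG', 'Resistor', 'LED', 'Capacitor', 'Diode']

# Made-up boxes as [xmin, ymin, xmax, ymax], matching how load_mask indexes box[0..3].
boxes = [[10, 20, 50, 60], [70, 15, 120, 40], [5, 5, 25, 30]]
h, w = 128, 128

# One mask channel per box, as in load_mask.
masks = np.zeros((h, w, len(boxes)), dtype='uint8')
class_ids = []
for i, (x0, y0, x1, y1) in enumerate(boxes):
    masks[y0:y1, x0:x1, i] = 1
    # Same pattern as load_mask: four class IDs appended for every single box.
    for name in ('Resistor', 'LED', 'Capacitor', 'Diode'):
        class_ids.append(class_names.index(name))
class_ids = np.asarray(class_ids, dtype='int32')

print(masks.shape[-1], class_ids.shape[0])    # 3 12
print(masks.shape[-1] == class_ids.shape[0])  # False
print((class_ids / 4).shape)                  # (12,)
```

With three boxes, the mask array has 3 channels but `class_ids` has 12 entries, the same 4× ratio seen with the real data. Dividing by 4 (as in the `class_ids/4` call in the traceback) only scales the values, not the length, so it cannot make the assertion pass.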
```python
from numpy import zeros, asarray

# load the masks for an image (method of my Dataset subclass)
def load_mask(self, image_id):
    # get details of image
    info = self.image_info[image_id]
    # define box file location
    path = info['annotation']
    # load XML
    boxes, w, h = self.extract_boxes(path)
    # create one array for all masks, each on a different channel
    masks = zeros([h, w, len(boxes)], dtype='uint8')
    # create masks
    class_ids = list()
    for i in range(len(boxes)):
        box = boxes[i]
        row_s, row_e = box[1], box[3]
        col_s, col_e = box[0], box[2]
        masks[row_s:row_e, col_s:col_e, i] = 1
        class_ids.append(self.class_names.index('Resistor'))
        class_ids.append(self.class_names.index('LED'))
        class_ids.append(self.class_names.index('Capacitor'))
        class_ids.append(self.class_names.index('Diode'))
    return masks, asarray(class_ids, dtype='int32')
```

```python
from mrcnn.visualize import display_instances

# load the masks and the class ids
mask, class_ids = train_set.load_mask(image_id)
print(mask, "and", class_ids)
# display image with masks and bounding boxes
display_instances(image, bbox, mask, class_ids, train_set.class_names)
```