I am trying to implement a multiclass image segmentation task. The mask of an image looks like the following:
I need to convert the mask into per-pixel labels, one per class, where the classes are red, green, blue, and black (background). I am doing this with the code below:
import torch

def mask_to_class(self, mask):
    # mask: H x W x 3 uint8 numpy array; returns an H x W LongTensor of class indices
    target = torch.from_numpy(mask)
    h, w = target.shape[0], target.shape[1]
    masks = torch.empty(h, w, dtype=torch.long)
    # collect every distinct RGB color present in the mask
    colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
    target = target.permute(2, 0, 1).contiguous()
    # assign each distinct color a class index
    mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
    for k in mapping:
        # pixels where all three channels match this color
        idx = (target == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
        validx = (idx.sum(0) == 3)
        masks[validx] = torch.tensor(mapping[k], dtype=torch.long)
    return masks
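For a clean mask containing only the four pure colors, I would expect mapping to come out roughly as below (torch.unique sorts the rows, and the pure RGB values are an assumption on my part, so the exact indices are only what I anticipate):

# Expected mapping for a mask with only the four pure colors (assumed RGB values):
expected_mapping = {
    (0, 0, 0): 0,      # black (background)
    (0, 0, 255): 1,    # blue
    (0, 255, 0): 2,    # green
    (255, 0, 0): 3,    # red
}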
The issue is that the number of distinct colors in the mask should be 4, but it comes out as 338 in this case, and the count also varies erratically between images.
The problem can be reproduced with the code below:
from PIL import Image
import numpy as np
import torch

image = Image.open("Image_Path")   # placeholder path
image = np.array(image)
target = torch.from_numpy(image)
h, w = target.shape[0], target.shape[1]
masks = torch.empty(h, w, dtype=torch.long)
colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
print(colors.shape)
(338, 3)
The shape of colors should be (4, 3), but it comes out as (338, 3).
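To see what the extra colors actually are, a small diagnostic like the one below can list them by pixel count. This is only a sketch, assuming the mask loads as an H x W x 3 uint8 array; "Image_Path" is still a placeholder:

from PIL import Image
import numpy as np

image = np.array(Image.open("Image_Path"))    # placeholder path
flat = image.reshape(-1, image.shape[2])      # one row per pixel
colors, counts = np.unique(flat, axis=0, return_counts=True)
order = np.argsort(-counts)                   # most frequent colors first
for c, n in zip(colors[order][:10], counts[order][:10]):
    print(tuple(c), n)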
I am unable to find the cause. Below is another image that is identical to the one above with respect to the colors present, yet it gives the required (4, 3) shape for colors.
Where am I going wrong?