I have an RGB image, uint8 of size (576,720,3), in which I want to classify each pixel into one of a set of colors. I converted it from RGB to LAB space using rgb2lab and then removed the L layer, so it is now a double(576,720,2) consisting of AB.

Now, I want to classify this to some colors that I have trained on another image, and calculated their respective AB-representations as:

Cluster 1: -17.7903  -13.1170
Cluster 2: -30.1957   40.3520
Cluster 3:  -4.4608   47.2543
Cluster 4:  46.3738   36.5225
Cluster 5:  43.3134  -17.6443
Cluster 6:  -0.9003    1.4042
Cluster 7:   7.3884   11.5584

Now, in order to classify/label each pixel to a cluster 1-7, I currently do the following (pseudo-code):

clusters;  % 7x2 array, one AB cluster center per row
for each x
  for each y
    ab = im(x,y,2:3);
    dist = norm(ab - clusters);  % norm of the difference between ab and each cluster
    [~, idx(x,y)] = min(dist);   % label of the nearest cluster
  end
end

However, this is terribly slow (52 seconds) because of the image resolution and because I manually loop over every x and y.

Are there built-in functions I can use that perform the same job? There must be.

To summarize: I need a classification method that assigns each pixel of an image to one of an already defined set of clusters.
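
For reference, the loop described above could be written as runnable MATLAB roughly as follows. This is only a sketch under stated assumptions: `im` here is a small random stand-in for the LAB image (not the real 576 x 720 data), the label map is stored in a variable named `idx`, and squared distances are used because they give the same nearest cluster as true distances.

```matlab
% Sketch of the per-pixel loop; im is a random stand-in for the LAB image.
im = rand(6, 8, 3) * 100 - 50;           % hypothetical small test image (L, A, B)
clusters = [-17.7903 -13.1170; -30.1957  40.3520;  -4.4608  47.2543; ...
             46.3738  36.5225;  43.3134 -17.6443;  -0.9003   1.4042; ...
              7.3884  11.5584];           % 7 x 2 cluster centers (A, B)

idx = zeros(size(im,1), size(im,2));      % label map, one cluster ID per pixel
for x = 1:size(im,1)
    for y = 1:size(im,2)
        ab = reshape(im(x,y,2:3), 1, 2);               % 1 x 2 AB value of this pixel
        d2 = sum(bsxfun(@minus, clusters, ab).^2, 2);  % squared distance to each cluster
        [~, idx(x,y)] = min(d2);                       % ID of the nearest cluster
    end
end
```

This is the baseline that a vectorized solution needs to reproduce: the same `idx` map, computed without the two explicit loops.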

Divakar
casparjespersen
  • Do you have any runtime comparison of the accepted solution against the actual code for the pseudo-code that you posted, if you did code it? – Divakar Nov 18 '14 at 15:19
  • @Divakar Yes, and it was actually quite interesting. My first attempt: 52 seconds. My first attempt, but migrated to use parallel computing (4 pools): 10 seconds. Approach #1: 0.06 seconds. Quite amazing. – casparjespersen Nov 19 '14 at 12:43
  • And with approach #2, have you had a chance to try out that too? Sorry to be noisy about these figures, but these kinda get me excited :) – Divakar Nov 19 '14 at 12:47
  • Hehe, it's cool :) I really enjoy the fact that matrix programming even outperforms parallel computing by that much! Is there any limit on how big matrices can be and still fit in memory? I haven't tried Approach #2 yet, but I can do it later today to check it out. – casparjespersen Nov 19 '14 at 12:50
  • Well, with approach #1 you would hit the memory bandwidth limit soon, but approach #2 should hold up better with large data sizes. I would be keen to know about runtime comparisons for approach #2, especially for large data sizes; if you would like to test it out, let me know! By the way, a big player in that great speedup magic of matrix programming, which we call vectorization, is `bsxfun`, the most versatile tool for vectorization! – Divakar Nov 19 '14 at 12:52
  • @Divakar Approach #1 ranges from 0.06-0.09 seconds whereas Approach #2 ranges from 0.04-0.06 seconds. So it's a wee bit faster. – casparjespersen Nov 19 '14 at 13:14
  • Awesome, thanks for getting those numbers! Kool! – Divakar Nov 19 '14 at 13:17

2 Answers

Approach #1

For an N x 2 array of points/pixels, you can avoid the permute used in the other solution by Luis, which could slow things down a bit. This gives a kind of "permute-unrolled" version of it and also lets bsxfun work on a 2D array instead of a 3D array, which should be better for performance.

Thus, assuming clusters is an M x 2 array (one cluster per row), you may try this other bsxfun based approach -

%// Get a's and b's
im_a = im(:,:,2);
im_b = im(:,:,3);

%// Get the minimum indices that correspond to the cluster IDs
[~,idx]  = min(bsxfun(@minus,im_a(:),clusters(:,1).').^2 + ...
    bsxfun(@minus,im_b(:),clusters(:,2).').^2,[],2);
idx = reshape(idx,size(im,1),[]);
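
To see what the two bsxfun calls build, here is a small sketch on made-up data. The four pixel values and the three cluster centers are hypothetical and chosen only so the result is easy to check by hand; the intermediate name `d2` is also an assumption for illustration.

```matlab
% Tiny made-up example: 4 pixels and 3 clusters.
im_a = [0; 10; -5; 40];                  % a-channel values, one per pixel
im_b = [0; 10; 45; -20];                 % b-channel values, one per pixel
clusters = [0 0; -4 47; 43 -18];         % hypothetical 3 x 2 cluster centers

% Each bsxfun call expands a 4 x 1 column against a 1 x 3 row into a 4 x 3 matrix.
d2 = bsxfun(@minus, im_a, clusters(:,1).').^2 + ...
     bsxfun(@minus, im_b, clusters(:,2).').^2;   % squared distances, pixels x clusters
[~, idx] = min(d2, [], 2);               % row-wise minimum: nearest cluster per pixel
disp(idx.')                              % labels for the 4 pixels: 1 1 2 3
```

The intermediate matrix has one row per pixel and one column per cluster, so its size grows as pixels times clusters. In MATLAB R2016b and later, implicit expansion should allow the same expression to be written without bsxfun, for example `(im_a - clusters(:,1).').^2`.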

Approach #2

You can try out another approach that leverages fast matrix multiplication in MATLAB and is based on this smart solution -

d = 2; %// dimension of the problem size

im23 = reshape(im(:,:,2:3),[],2);

numA = size(im23,1);
numB = size(clusters,1);

A_ext = zeros(numA,3*d);
B_ext = zeros(numB,3*d);
for id = 1:d
    A_ext(:,3*id-2:3*id) = [ones(numA,1), -2*im23(:,id), im23(:,id).^2 ];
    B_ext(:,3*id-2:3*id) = [clusters(:,id).^2 ,  clusters(:,id), ones(numB,1)];
end
[~, idx] = min(A_ext * B_ext.',[],2); %// Nearest cluster per pixel
idx = reshape(idx, size(im,1),[]); %// Desired IDs

What’s going on with the matrix multiplication based distance matrix calculation?

Let us consider two matrices A and B between which we want to calculate the distance matrix. To keep the explanation simple, let us take A as a 3 x 2 and B as a 4 x 2 sized array, meaning that we are working with X-Y points. If we had A as N x 3 and B as M x 3 sized arrays, then those would be X-Y-Z points.

Now, if we have to manually calculate the first element of the squared distance matrix, it would look like this -

first_element = ( A(1,1) - B(1,1) )^2 + ( A(1,2) - B(1,2) )^2

which would be –

first_element = A(1,1)^2 + B(1,1)^2 -2*A(1,1)* B(1,1)   +  ...
                A(1,2)^2 + B(1,2)^2 -2*A(1,2)* B(1,2)    … Equation  (1)

Now, according to our proposed matrix multiplication, if you check the output of A_ext and B_ext after the loop in the earlier code ends, they would look like the following –

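A_ext =
    1   -2*A(1,1)   A(1,1)^2   1   -2*A(1,2)   A(1,2)^2
    1   -2*A(2,1)   A(2,1)^2   1   -2*A(2,2)   A(2,2)^2
    1   -2*A(3,1)   A(3,1)^2   1   -2*A(3,2)   A(3,2)^2

B_ext =
    B(1,1)^2   B(1,1)   1   B(1,2)^2   B(1,2)   1
    B(2,1)^2   B(2,1)   1   B(2,2)^2   B(2,2)   1
    B(3,1)^2   B(3,1)   1   B(3,2)^2   B(3,2)   1
    B(4,1)^2   B(4,1)   1   B(4,2)^2   B(4,2)   1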

So, if you perform matrix multiplication between A_ext and transpose of B_ext, the first element of the product would be the sum of elementwise multiplication between the first rows of A_ext and B_ext, i.e. sum of these –

    1*B(1,1)^2,   -2*A(1,1)*B(1,1),   A(1,1)^2*1,   1*B(1,2)^2,   -2*A(1,2)*B(1,2),   A(1,2)^2*1

The result is identical to the one obtained from Equation (1) earlier. The same holds for every row of A paired with every row of B, with each column of A multiplied against the same column of B. Thus, we end up with the complete squared distance matrix. That's all there is to it!
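
The identity can be checked numerically. The sketch below uses made-up integer X-Y points for a 3 x 2 A and a 4 x 2 B (hypothetical values, not taken from the question) and compares the matrix product against a direct evaluation; the names `distmat`, `direct` and `err` are assumptions for illustration.

```matlab
% Made-up X-Y points: A is 3 x 2 and B is 4 x 2, as in the explanation.
A = [1 2; 3 4; 5 6];
B = [0 0; 1 1; 2 0; -1 3];
[nA, d] = size(A);
nB = size(B, 1);

% Build the extended matrices exactly as in the loop of Approach #2.
A_ext = zeros(nA, 3*d);
B_ext = zeros(nB, 3*d);
for id = 1:d
    A_ext(:, 3*id-2:3*id) = [ones(nA,1), -2*A(:,id), A(:,id).^2];
    B_ext(:, 3*id-2:3*id) = [B(:,id).^2, B(:,id), ones(nB,1)];
end
distmat = A_ext * B_ext.';               % 3 x 4 matrix of squared distances

% Direct evaluation of the first element, as in Equation (1).
first_element = (A(1,1) - B(1,1))^2 + (A(1,2) - B(1,2))^2;

% Direct evaluation of every element for comparison.
direct = zeros(nA, nB);
for i = 1:nA
    for j = 1:nB
        direct(i,j) = sum((A(i,:) - B(j,:)).^2);
    end
end
err = max(abs(distmat(:) - direct(:)));  % zero here, since all values are small integers
```

With these points, `first_element` and `distmat(1,1)` both come out as 5. With non-integer data the two results should agree only up to floating-point rounding.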

Vectorized Variations

Vectorized variations of the matrix multiplication based distance matrix calculation are possible, though they did not show any big performance improvement. Two such variations are listed next.

Variation #1

[nA,dim] = size(A);
nB = size(B,1);

A_ext = ones(nA,dim*3);
A_ext(:,2:3:end) = -2*A;
A_ext(:,3:3:end) = A.^2;

B_ext = ones(nB,dim*3);
B_ext(:,1:3:end) = B.^2;
B_ext(:,2:3:end) = B;

distmat = A_ext * B_ext.';

Variation #2

[nA,dim] = size(A);
nB = size(B,1);

A_ext = [ones(nA*dim,1) -2*A(:) A(:).^2];
B_ext = [B(:).^2 B(:) ones(nB*dim,1)];

A_ext = reshape(permute(reshape(A_ext,nA,dim,[]),[1 3 2]),nA,[]);
B_ext = reshape(permute(reshape(B_ext,nB,dim,[]),[1 3 2]),nB,[]);

distmat = A_ext * B_ext.';

So, these could be considered as experimental versions too.
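
As a quick sanity check, the two variations can be run side by side on the same made-up points to confirm that they build identical extended matrices. This is only a sketch: the data and the suffixed names (`A1`, `B1`, `A2`, `B2`, `same_ext`) are hypothetical and introduced for the comparison.

```matlab
% Made-up X-Y points: A is 3 x 2 and B is 4 x 2.
A = [1 2; 3 4; 5 6];
B = [0 0; 1 1; 2 0; -1 3];
[nA, dim] = size(A);
nB = size(B, 1);

% Variation #1: fill every third column by strided indexing.
A1 = ones(nA, dim*3);  A1(:,2:3:end) = -2*A;  A1(:,3:3:end) = A.^2;
B1 = ones(nB, dim*3);  B1(:,1:3:end) = B.^2;  B1(:,2:3:end) = B;

% Variation #2: build three tall columns, then reshape and permute into place.
A2 = [ones(nA*dim,1) -2*A(:) A(:).^2];
B2 = [B(:).^2 B(:) ones(nB*dim,1)];
A2 = reshape(permute(reshape(A2, nA, dim, []), [1 3 2]), nA, []);
B2 = reshape(permute(reshape(B2, nB, dim, []), [1 3 2]), nB, []);

same_ext = isequal(A1, A2) && isequal(B1, B2);  % both variations give the same layout
distmat = A1 * B1.';                            % squared distance matrix, 3 x 4
```

Since both variations produce the same `A_ext` and `B_ext`, any runtime difference between them comes from how the matrices are assembled, not from the final multiplication.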

Community
Divakar
  • Apologies but my Linear Algebra is a bit rusty. I wish you would create a detailed explanation of the smart solution because you keep posting it and I don't get it completely. Especially `HelpA`, `HelpB` and `helpA * helpB'`. Why do you do `ones(numA,1)`? Why do you use `-2` in `-2*im23(:,id)`? Why do you create the values of `helpA` and `HelpB` in that order? What is the purpose of `helpA * helpB'`? – kkuilla Nov 18 '14 at 13:37
  • @kkuilla See if the explanation in the edits make sense? – Divakar Nov 18 '14 at 18:22
  • @Divakar Good improvement on `bsxfun`, and very thorough answer! +1 – Luis Mendo Nov 18 '14 at 23:30
  • Oh, yes. Fantastic. Thank you. +50 :-). I didn't see that you re-wrote `(A(1,1)-B(1,1))^2 + (A(1,2)-B(1,2))^2` using that rule I can't remember the name of. Excellent explanation. – kkuilla Nov 19 '14 at 09:26
  • @kkuilla Funny thing, even I forgot the name of it; elementary school's been long gone, but I think it could be called the expansion of the square of a difference, maybe :) – Divakar Nov 19 '14 at 09:30

Use pdist2 (Statistics Toolbox) to compute the distances in a vectorized manner:

ab = im(:,:,2:3);                              % // get A, B components
ab = reshape(ab, [size(im,1)*size(im,2) 2]);   % // reshape into 2-column
dist = pdist2(clusters, ab);                   % // compute distances
[~, idx] = min(dist);                          % // find minimizer for each pixel
idx = reshape(idx, size(im,1), size(im,2));    % // reshape result

If you don't have the Statistics Toolbox, you can replace the third line with

dist = squeeze(sum(bsxfun(@minus, clusters, permute(ab, [3 2 1])).^2, 2));

This gives squared distance instead of distance, but for the purposes of minimizing it doesn't matter.
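
A small sketch on made-up data illustrates why the squared distance is enough. The cluster centers and pixel values below are hypothetical, and `idx_sq` / `idx_eu` are names introduced only for this comparison.

```matlab
clusters = [0 0; -4 47; 43 -18];         % hypothetical 3 x 2 cluster centers
ab = [0 0; 10 10; -5 45; 40 -20];        % hypothetical 4 pixels, one AB row each

% permute turns ab into 1 x 2 x 4, so bsxfun yields a 3 x 2 x 4 difference array;
% summing over dimension 2 and squeezing gives a 3 x 4 (clusters x pixels) matrix.
dist2 = squeeze(sum(bsxfun(@minus, clusters, permute(ab, [3 2 1])).^2, 2));

[~, idx_sq] = min(dist2);                % minimizer over clusters, squared distances
[~, idx_eu] = min(sqrt(dist2));          % minimizer over clusters, true distances
```

Because the square root is monotonically increasing, `idx_sq` and `idx_eu` are the same, so the `sqrt` can be skipped when only the nearest cluster is needed.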

user3666197
Luis Mendo