
I've created a codebook using k-means of size 4000x300 (4000 centroids, each with 300 features). Using the codebook, I then want to label input vectors (for purposes of binning later on). The input is of size Nx300, where N is the total number of input instances I receive.

To compute the labels, I calculate the closest centroid for each of the input vectors. To do so, I compare each input vector against all centroids and pick the centroid with the minimum distance. The label is then just the index of that centroid.

My current Matlab code looks like:

function labels = assign_labels(centroids, X)
labels = zeros(size(X, 1), 1);

% for each row of X, calculate the squared distance from each centroid
for i = 1:size(X, 1)
    % distance of X_i from all j centroids is: sum((X_i - centroid_j).^2)
    % note: we leave off the sqrt as an optimization
    distances = sum(bsxfun(@minus, centroids, X(i, :)) .^ 2, 2);
    [~, label] = min(distances);
    labels(i) = label;
end

However, this code is still fairly slow (for my purposes), and I was hoping there might be a way to optimize the code further.

One obvious issue is the for-loop, which is the bane of good performance in Matlab. I've been trying to come up with a way to get rid of it, but with no luck (I looked into using arrayfun in conjunction with bsxfun, but haven't gotten that to work). Alternatively, if anyone knows of any other way to speed this up, I would greatly appreciate it.
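For reference, one hypothetical way to combine arrayfun with bsxfun here (a sketch only; arrayfun still loops internally, so it is not expected to be faster than the explicit loop):

% squared distances from one input row to all centroids
dist2 = @(x) sum(bsxfun(@minus, centroids, x) .^ 2, 2);
% index of the smallest entry of a vector
argmin = @(d) find(d == min(d), 1);
labels = arrayfun(@(i) argmin(dist2(X(i, :))), (1:size(X, 1))');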

Update

After doing some searching, I couldn't find a great solution using Matlab, so I decided to look at how Python's scikits.learn package computes 'euclidean_distance' (shortened):

from numpy import sum, newaxis, dot, maximum  # imports implied by the original source

XX = sum(X * X, axis=1)[:, newaxis]
YY = Y.copy()
YY **= 2
YY = sum(YY, axis=1)[newaxis, :]
distances = XX + YY
distances -= 2 * dot(X, Y.T)
distances = maximum(distances, 0)

which uses the expanded (binomial) form of the squared Euclidean distance ((x-y)^2 = x^2 + y^2 - 2xy), which from what I've read usually runs faster. My completely untested Matlab translation is:

XX = sum(data .^ 2, 2);                          % N-by-1 squared norms of the data rows
YY = sum(center .^ 2, 2);                        % K-by-1 squared norms of the centroids
dist = bsxfun(@plus, XX, YY') - 2*data*center'; % N-by-K squared distances
[~, labels] = min(dist, [], 2);                  % index of the nearest centroid
Abe Schneider
  • related: [pdist2 equivalent in MATLAB version 7](http://stackoverflow.com/a/7774323/97160) – Amro Jul 10 '12 at 20:26

4 Answers


Use the following function to calculate your distances; you should see an order-of-magnitude speed-up.

The two matrices A and B have the columns as the dimensions and the rows as the points: A is your matrix of centroids, and B is your matrix of data points.

function D = getSim(A, B)
    % squared distances between all rows of A and all rows of B:
    % D(i,j) = |A(i,:)|^2 + |B(j,:)|^2 - 2*A(i,:)*B(j,:)'
    Qa = repmat(dot(A, A, 2), 1, size(B, 1));
    Qb = repmat(dot(B, B, 2), 1, size(A, 1));
    D = Qa + Qb' - 2*A*B';
end
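A usage sketch, with variable names as in the question:

D = getSim(centroids, X);     % K-by-N matrix of squared distances
[~, labels] = min(D, [], 1);  % nearest centroid for each data point
labels = labels(:);           % column vector, as in the original code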
twerdster
  • +1 to you. This is indeed faster. Come to think of it, I don't know why I didn't do this instead of using `cellfun`. I've been using `cellfun` way too much of late, and should retire it from my toolbag, at least temporarily :) – abcd Apr 24 '11 at 17:02

You can use a more efficient algorithm for nearest-neighbor search than brute force. The most popular approach is the k-d tree, with O(log(n)) average query time instead of the O(n) brute-force complexity. Regarding a Matlab implementation of k-d trees, you can have a look here
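A minimal sketch, assuming the Statistics Toolbox is available; note that knnsearch uses a k-d tree only for low-dimensional data and falls back to exhaustive search otherwise, so with 300 dimensions it is worth benchmarking:

% index of the nearest centroid for each row of X
labels = knnsearch(centroids, X);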

Quant

You can vectorize it by converting to cells and using cellfun:

[nRows, nCols] = size(X);
nCentroids = size(centroids, 1);
XCell = num2cell(X, 2);
dist = reshape(cell2mat(cellfun(@(x) sum(bsxfun(@minus, centroids, x).^2, 2), XCell, 'UniformOutput', false)), nCentroids, nRows);
[~, labels] = min(dist);

Explanation:

  • We assign each row of X to its own cell in the second line
  • This piece @(x)(sum(bsxfun(@minus,centroids,x).^2,2)) is an anonymous function that computes the same thing as your distances=... line; cellfun applies it to each cell (i.e., each row of X), and cell2mat concatenates the results.
  • The labels are then the row indices of the minimum in each column.
abcd
  • It looks like num2cell(X) turns each element of X into its own vector. However, for (x-centroids) we want to subtract each row of X against each centroid. So should it read instead: XCell=num2cell(X, 2)? – Abe Schneider Apr 21 '11 at 14:38
  • 1
  • @Abe, you're right. While it certainly returns the same answer, there is unnecessary function call overhead in applying it to each element. I've fixed that. But do remember that `bsxfun` and `cellfun` are usually just terse ways of writing a loop, and need not necessarily be faster (sometimes they are, but not always). Timing your loop and the cellfun code for a matrix of the same dimension as yours, they were pretty even at around 93 seconds, differing only in the tenths place. – abcd Apr 21 '11 at 15:46
  • 1
  • Great, thanks! Yeah, I was hoping that Matlab's internal loop might be better than their general for-loops, but I'm not seeing a great improvement either. Also, the for-loop allows for 'parfor'. I originally wanted to use the GPU, but bsxfun doesn't allow anonymous functions, and thus there's no way to pass 'centroids'. One improvement I did manage to find was someone else's posting on SO that loops over the centroids instead of the data. I have 4000 centroids and anywhere from 1e5 to 1e6 feature vectors in my data, so by looping over the centroids, I can speed things up with matrix math (see the sketch after these comments). – Abe Schneider Apr 21 '11 at 18:04
  • I didn't have space in my previous comment, but I used the label assignment code from here: http://stackoverflow.com/questions/1373516/matlabk-means-clustering/1400760#1400760 – Abe Schneider Apr 21 '11 at 18:06
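A minimal sketch of that centroid-loop idea (not the exact code from the linked answer):

% loop over the (few) centroids instead of the (many) data points
bestDist = inf(size(data, 1), 1);
labels = zeros(size(data, 1), 1);
for j = 1:size(centroids, 1)
    d = sum(bsxfun(@minus, data, centroids(j, :)) .^ 2, 2);  % N-by-1
    closer = d < bestDist;          % points now closer to centroid j
    bestDist(closer) = d(closer);
    labels(closer) = j;
end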

For a true matrix implementation, you may consider trying something along the lines of:

  % replicate each centroid once per data point, and tile the data once per
  % centroid, so the subtraction pairs every point with every centroid
  P2 = kron(centroids, ones(size(X,1),1));
  Q2 = kron(ones(size(centroids,1),1), X);

  distances = reshape(sum((Q2-P2).^2,2), size(X,1), size(centroids,1));

Note: this assumes the data is organized as [x1 y1 ...; x2 y2 ...; ...]
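If kron is unfamiliar, an equivalent formulation using indexing and repmat (a sketch under the same assumptions):

N = size(X, 1);
K = size(centroids, 1);
P2 = centroids(ceil((1:K*N).' / N), :);  % each centroid repeated N times
Q2 = repmat(X, K, 1);                    % the data tiled K times

distances = reshape(sum((Q2 - P2).^2, 2), N, K);
[~, labels] = min(distances, [], 2);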

jmetz
  • This uses a lot of memory - as any truly vectorized version will. The (slower) alternative is already provided by @yoda. Also you may be able to replace kron with repmat - this was just a version I had already used in one of my own projects – jmetz Apr 22 '11 at 14:57
  • Hmm. I'm going to try this, since my current approach is going to take > 4 days. I'm not familiar with kron, though, so I'm going to attempt to rewrite it using repmats above. If you have a chance, do you mind verifying? – Abe Schneider Apr 22 '11 at 19:39
  • Never mind. On paper, it looks like if I make both matrices 3-dimensional (repeat the data matrix size(centroids, 1) times, and repeat each centroid size(data, 1) times), I can just subtract, square, sum, and then min them. However, I can't figure out a good way to build the second matrix. – Abe Schneider Apr 22 '11 at 19:55
  • Sure - could you give me an idea how big `N` will be? As I mentioned, the memory usage can get big pretty quickly - I don't think it'll handle N much more than a few hundred or so max, given the 300-d measurement space. – jmetz Apr 22 '11 at 19:58
  • Okay, good to know. I'm reading up on krons now, but I haven't seen it used this way before. As for N, it's fairly large, so that might be an issue. My data is generally 100,000 and my centroids are around 4,000. I can change the size of the data since I'm paging it in, though if it's too small, I worry that will incur a high cost. – Abe Schneider Apr 22 '11 at 20:07
  • This unnecessarily stores the data which becomes too costly for large arrays. Each point in the distance matrix is just the dot product which can be written as |x|^2 -2*x'*y+|y|^2 which in turn is quite easy to turn into a matrix calculation. – twerdster Apr 25 '11 at 17:19
  • @mutzmatron - He doesnt need to iterate through the datasets at all. Check my answer for the code. – twerdster Apr 25 '11 at 22:24
  • @mutzmatron Your method stores data unnecessarily. For centroids of size 4000x300 and X of size 10000x300 my method uses about 650mb and a few seconds. Your method immediately runs out of memory. – twerdster Apr 27 '11 at 19:51
  • @twerdster - sincere apologies - you are right, your code doesn't store the intermediate matrices before the sums, thereby avoiding the overhead I mentioned. +1! – jmetz Apr 27 '11 at 20:32