
Given two sets of d-dimensional points, how can I most efficiently compute the pairwise squared Euclidean distance matrix in MATLAB?

Notation: Set one is given by a (numA,d)-matrix A and set two by a (numB,d)-matrix B. The resulting distance matrix should have size (numA,numB).

Example points:

d = 4;            % dimension
numA = 100;       % number of set 1 points
numB = 200;       % number of set 2 points
A = rand(numA,d); % set 1 given as matrix A
B = rand(numB,d); % set 2 given as matrix B
matheburg

2 Answers


The usual answer here is based on bsxfun (cf. e.g. [1]). My proposed approach is based on matrix multiplication and turns out to be much faster than any comparable algorithm I could find:

% Per dimension k, the block dot product gives
% A(i,k)^2 - 2*A(i,k)*B(j,k) + B(j,k)^2 = (A(i,k) - B(j,k))^2;
% the matrix product helpA * helpB' sums this over all k.
helpA = zeros(numA,3*d);
helpB = zeros(numB,3*d);
for idx = 1:d
    helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*A(:,idx), A(:,idx).^2 ];
    helpB(:,3*idx-2:3*idx) = [B(:,idx).^2 ,    B(:,idx), ones(numB,1)];
end
distMat = helpA * helpB';
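
A quick sanity check (assuming pdist2 from the Statistics Toolbox is available) is to compare against a reference result; the difference should be floating-point noise:

% sanity check: requires pdist2 (Statistics Toolbox)
ref = pdist2(A,B).^2;
max(abs(distMat(:) - ref(:))) % expect something on the order of 1e-12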

Please note: For constant d one can replace the for-loop by a hardcoded implementation, e.g.

helpA = [ones(numA,1), -2*A(:,1), A(:,1).^2, ... % hardcoded for d == 2
         ones(numA,1), -2*A(:,2), A(:,2).^2 ];   % (extend analogously for larger d)

Evaluation:

%% create some points
d = 2; % dimension
numA = 20000;
numB = 20000;
A = rand(numA,d);
B = rand(numB,d);

%% pairwise distance matrix
% proposed method:
tic;
helpA = zeros(numA,3*d);
helpB = zeros(numB,3*d);
for idx = 1:d
    helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*A(:,idx), A(:,idx).^2 ];
    helpB(:,3*idx-2:3*idx) = [B(:,idx).^2 ,    B(:,idx), ones(numB,1)];
end
distMat = helpA * helpB';
toc;

% compare to pdist2:
tic;
pdist2(A,B).^2;
toc;

% compare to [1]:
tic;
bsxfun(@plus,dot(A,A,2),dot(B,B,2)')-2*(A*B');
toc;

% Another method: added 07/2014
% compare to ndgrid method (cf. Dan's comment)
tic;
[idxA,idxB] = ndgrid(1:numA,1:numB);
distMat = zeros(numA,numB);
distMat(:) = sum((A(idxA,:) - B(idxB,:)).^2,2);
toc;

Result:

Elapsed time is 1.796201 seconds.
Elapsed time is 5.653246 seconds.
Elapsed time is 3.551636 seconds.
Elapsed time is 22.461185 seconds.

For a more detailed evaluation with respect to dimension and number of data points, follow the discussion below (see the comments). It turns out that different algorithms should be preferred in different settings. In non-time-critical situations, just use the pdist2 version.
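
As pointed out in the comments below, newer MATLAB releases can also return squared distances from pdist2 directly, which avoids the element-wise squaring (this option postdates the answer, so check your release):

distMat = pdist2(A,B,'squaredeuclidean'); % newer releases only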

Further development: One can think of replacing the squared Euclidean distance by any other metric based on the same principle:

help = zeros(numA,numB,d);
for idx = 1:d
    % help(:,:,idx) holds the pairwise differences B(j,idx) - A(i,idx)
    help(:,:,idx) = [ones(numA,1), A(:,idx)     ] * ...
                    [B(:,idx)'   ; -ones(1,numB)];
end
distMat = sum(ANYFUNCTION(help),3);
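
For instance, choosing abs for ANYFUNCTION yields the pairwise Manhattan (L1) distance matrix:

distMat = sum(abs(help),3); % sum_k |B(j,k) - A(i,k)|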

Nevertheless, this is quite time consuming. For smaller d it could be useful to replace the 3-dimensional matrix help by d 2-dimensional matrices. For d = 1 in particular, it provides a way to compute the pairwise differences with a single matrix multiplication:

pairDiffs = [ones(numA,1), A ] * [B'; -ones(1,numB)];
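
Note the sign convention: entry (i,j) of pairDiffs equals B(j) - A(i). A tiny check with made-up values:

a = (1:3)'; b = [10; 20];
[ones(3,1), a] * [b'; -ones(1,2)] % entry (1,1) is b(1) - a(1) = 9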

Do you have any further ideas?

matheburg
  • Really interesting! +1. In another story: on my machine, starting at about `d>30`, `bsxfun` wins again due to lower memory overhead. – knedlsepp Jan 30 '15 at 13:00
  • @knedlsepp Thanks for taking the time to put all those together! Well, I did benchmark those two vectorized versions against the loop-based version as proposed here and I didn't see a lot of difference, at least not for small to decently sized `dims`. – Divakar Jan 31 '15 at 09:22
  • @Divakar: As on my machine: If we want squared distances, your `Vec1` version is the fastest for lower dimensions until it gets beat by `bsxfun`. If we want the actual `sqrt`-distances `pdist2` is faster until it also gets beat by `bsxfun` eventually. After doing all this comparing: I guess, even though it's nice to know that we can squeeze the last bit of speed from all of this, I somehow get the feeling that simply going with `pdist2` is a no-brainer, if you have the statistics toolbox installed, as it is flexible yet still very very fast. – knedlsepp Jan 31 '15 at 13:11
  • @knedlsepp Thanks a lot - this is a very interesting evaluation! I would just like to add that the log10 time scale is a little misleading, since the relevance of computation time does not live on a log scale (e.g. a factor of 2 is a really interesting saving, but looks like nothing on a log10 scale). My conclusion: it pays off to test different algorithms for a time-critical implementation (which is the case especially for large numbers of points). E.g. for large numbers of 2d data points it turns out to be useful to use my implementation. I really like our collection of algorithms! :) – matheburg Feb 25 '15 at 07:24
  • @Divakar Your vectorized versions are interesting variations of the matrix approach! *thumbs up* :) – matheburg Feb 25 '15 at 07:39
  • @matheburg Thanks! Good to get your approval on those! :) – Divakar Feb 25 '15 at 09:35
  • @matheburg: I usually use `loglog` plots to compare algorithms. (From what I can tell this is also quite common) These have the benefit that one can easier tell if two algorithms are in the same complexity class by looking at their slopes. (Which of course is more theoretically interesting than practically, but I think one can still easily make out which algorithm is fastest) – knedlsepp Feb 25 '15 at 10:49
  • This is a very interesting suggestion and comparison. It seems that the pdist2 version loses efficiency mainly due to the element-wise squares, while MATLAB now provides the 'squaredeuclidean' option to get this directly. With this, the proposed method and pdist2 appear to be very close (and perhaps pdist2 is faster in some regimes). The option may be more recent than the posted answer. – akkapi Aug 15 '21 at 00:18
  • @akkapi Good catch! I would expect the new `pdist2` option `'squaredeuclidean'` to be at least as efficient as my solution from 2014. Maybe you (or someone else) could test this and provide a new answer? Unfortunately, I have no access to an up-to-date MATLAB version. – matheburg Aug 15 '21 at 09:17

For the squared Euclidean distance one can also use the following formula:

||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>

where <a,b> is the dot product of a and b:

nA = sum( A.^2, 2 ); %// squared norms of A's rows
nB = sum( B.^2, 2 ); %// squared norms of B's rows
distMat = bsxfun( @plus, nA, nB' ) - 2 * A * B' ;

Recently, I've been told that as of R2016b this method for computing squared Euclidean distances is faster than the accepted method.
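
Since R2016b MATLAB also supports implicit expansion, so the bsxfun call above can be written more compactly (same result; nA and nB as defined above):

distMat = nA + nB' - 2 * A * B'; % implicit expansion, R2016b+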

Shai