
I have the following Python function and cannot figure out how to express it in vectorized form. Here cov is a numpy array of shape (2,2), mu is the mean vector with shape (2,), and xtp has shape (~50000,2). I know scipy provides scipy.stats.multivariate_normal.pdf, but I am trying to learn how to write efficient vectorized code.

import numpy as np

def mvnpdf(xtp, mu, cov):
    temp = np.zeros(xtp.shape[0])
    i = 0
    length = xtp.shape[0]
    const = 1 / ( ((2* np.pi)**(len(mu)/2)) * (np.linalg.det(cov)**(1/2)) )
    inv = np.linalg.inv(cov)
    while i < length:
        x = xtp[i]-mu
        exponent = (-1/2) * (x.dot(inv).dot(x))
        temp[i] =  (const * np.exp(exponent))
        i+=1
    return temp
enitihas

1 Answer

The only tricky part to vectorize is that double .dot. Let's isolate that:

x = xtp - mu  # move this out of the loop
ddot = np.array([i.dot(inv).dot(i) for i in x])  # array, not list, so -0.5 * ddot works
temp = const * np.exp(-0.5 * ddot)

Put that in your code and see if it produces the same thing.

There are several ways of 'vectorizing' a dot. The one I like to try first is einsum. In my tests this is equivalent:

ddot = np.einsum('ij,jk,ik->i',x,inv,x)
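Putting the pieces together, here is a minimal sketch of what the whole function might look like with the loop replaced by einsum. The name mvnpdf_einsum and the sample mu, cov and xtp values are made up for illustration; only the einsum expression and the constant come from the code above.

```python
import numpy as np

def mvnpdf_einsum(xtp, mu, cov):
    """Multivariate normal density, one value per row of xtp (no Python loop)."""
    k = len(mu)
    const = 1.0 / ((2 * np.pi) ** (k / 2) * np.sqrt(np.linalg.det(cov)))
    inv = np.linalg.inv(cov)
    x = xtp - mu  # broadcasting subtracts mu from every row
    # Quadratic form x_i . inv . x_i for every row i, in a single call
    ddot = np.einsum('ij,jk,ik->i', x, inv, x)
    return const * np.exp(-0.5 * ddot)

# Made-up sample data, much smaller than the ~50000 rows in the question
rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.3], [0.3, 0.5]])
xtp = rng.normal(size=(1000, 2))
pdf = mvnpdf_einsum(xtp, mu, cov)
```

The result should agree with the original while-loop version to floating-point tolerance, which np.allclose can confirm.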

I'd suggest trying it to see if it works and speeds things up. And do play around with these calculations with smaller arrays (not the ~50000) in an interactive shell.
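For playing around in a shell, here is a small sketch that computes the same row-wise quadratic form three ways: the list comprehension, a matrix product followed by a row-wise sum, and einsum. The 6-row x and the cov used to build inv are made-up test values, and the matmul variant is one alternative I'm assuming you might try, not the only one.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(6, 2))                       # small made-up stand-in for xtp - mu
inv = np.linalg.inv(np.array([[2.0, 0.3],
                              [0.3, 0.5]]))       # made-up covariance, inverted

# 1. Row-by-row double dot (the slow reference)
ddot_loop = np.array([row.dot(inv).dot(row) for row in x])

# 2. One matrix product, then an elementwise multiply and a sum over columns
ddot_matmul = ((x @ inv) * x).sum(axis=1)

# 3. einsum: sum over j and k, keep the row index i
ddot_einsum = np.einsum('ij,jk,ik->i', x, inv, x)

print(np.allclose(ddot_loop, ddot_matmul), np.allclose(ddot_loop, ddot_einsum))
```

Which of 2 and 3 is faster can depend on array sizes and the numpy build, so it is worth timing both on your own data.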

I'm testing things with

In [225]: x
Out[225]: 
array([[  0.,   2.],
       [  1.,   3.],
       ...
       [  7.,   9.],
       [  8.,  10.],
       [  9.,  11.]])
In [226]: inv
Out[226]: 
array([[ 1.,  0.],
       [ 0.,  1.]])

Since this is a learning exercise I'll leave the details to you.

With a (2,2) cov, the determinant and inverse might be faster if computed explicitly rather than with the det and inv functions. But it's the iteration over length that's the real time consumer.
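As a rough sketch of the explicit (2,2) calculation, using the standard closed-form determinant and inverse of a 2x2 matrix. The helper name det_inv_2x2 and the sample cov are hypothetical, and whether this actually beats np.linalg is something to time rather than assume.

```python
import numpy as np

def det_inv_2x2(cov):
    """Closed-form determinant and inverse of a 2x2 matrix [[a, b], [c, d]]."""
    a, b = cov[0]
    c, d = cov[1]
    det = a * d - b * c
    # Inverse is the adjugate [[d, -b], [-c, a]] divided by the determinant
    inv = np.array([[d, -b], [-c, a]]) / det
    return det, inv

cov = np.array([[2.0, 0.3],
                [0.3, 0.5]])   # made-up covariance for illustration
det, inv = det_inv_2x2(cov)
```

This runs once per call, so any saving is small next to removing the per-row loop.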

hpaulj
  • Thanks a lot. einsum looks awesome. The resulting code is a lot faster. Any good resources for learning einsum? – enitihas Sep 12 '15 at 00:43
    @enitihas: there are relatively few resources scattered around the web, but here's a SO question with some answers and links: http://stackoverflow.com/questions/26089893/understanding-einsum-numpy – Alex Riley Sep 12 '15 at 09:57