
I'm looking for the most memory-efficient way to compute the absolute squared value of a complex numpy ndarray:

arr = np.empty((250000, 150), dtype='complex128')  # common size

I haven't found a ufunc that would do exactly np.abs()**2.

As an array of that size and type takes up around half a GB, I'm looking primarily for a memory-efficient way to do this.

I would also like it to be portable, so ideally it would use some combination of ufuncs.

So far my understanding is that this should be about the best:

result = np.abs(arr)
result **= 2

It needlessly computes (**0.5)**2, but the **2 step runs in place. Altogether the peak memory requirement is the original array size plus the result array size, i.e. 1.5 * the original array size, since the result is real.

If I wanted to get rid of the useless **2 call, I'd have to do something like this:

result = arr.real**2
result += arr.imag**2

but if I'm not mistaken, this means memory has to be allocated for both the real and the imaginary part calculation, so the peak memory usage would be 2.0 * the original array size. The arr.real and arr.imag attributes also return non-contiguous arrays (but that is of lesser concern).

Is there anything I'm missing? Are there any better ways to do this?

EDIT 1: I'm sorry for not making it clear: I don't want to overwrite arr, so I can't use it as out.

– Ondřej Grover

5 Answers


Thanks to numba.vectorize in recent versions of numba, creating a numpy universal function for the task is very easy:

import numba
import numpy as np

# a compiled NumPy ufunc computing |x|**2 elementwise, with no intermediate arrays
@numba.vectorize([numba.float64(numba.complex128), numba.float32(numba.complex64)])
def abs2(x):
    return x.real**2 + x.imag**2

On my machine, I find a threefold speedup compared to a pure-numpy version that creates intermediate arrays:

>>> x = np.random.randn(10000).view('c16')
>>> y = abs2(x)
>>> np.all(y == x.real**2 + x.imag**2)   # exactly equal, being the same operation
True
>>> %timeit np.abs(x)**2
10000 loops, best of 3: 81.4 µs per loop
>>> %timeit x.real**2 + x.imag**2
100000 loops, best of 3: 12.7 µs per loop
>>> %timeit abs2(x)
100000 loops, best of 3: 4.6 µs per loop
– burnpanck
    I'd like to accept this as an answer, but I'm not sure how portable it is. Numba is pretty easy to install these days with Anaconda on most machines, but I'm not sure how portable across architectures the actual LLVM bindings are. Perhaps you could add some info about the portability of this answer. – Ondřej Grover Mar 21 '17 at 09:03
  • Well, I'm no LLVM expert, but the documentation of the current version (0.31.0) says: Supported are Linux, Windows 7 and OS X 10.9 and later. – burnpanck Mar 21 '17 at 14:52

EDIT: this solution has twice the minimum memory requirement and is only marginally faster. The discussion in the comments is a useful reference, however.

Here's a faster solution, with the result stored in res:

import numpy as np

res = arr.conjugate()           # allocates a complex array the same size as arr
np.multiply(arr, res, out=res)  # arr * conj(arr), written in place into res

where we exploit the property abs(z) = sqrt(z * z.conjugate()), so that abs(z)**2 = z * z.conjugate(). For z = a + bi this is (a + bi)(a - bi) = a**2 + b**2, with the cross terms -ab + ba cancelling exactly, even in floating point.

– gg349
    I was also thinking about this, but this has the problem that the result is still complex. Additionally, the peak memory consumption is 2.0 * original array size. I could simply take the real part (as the imag part should be very close to 0) , but that would either further increase the peak memory consumption or give me a non-contiguous array. Also, the multiplication of complex numbers will perform many unnecessary multiplications and additions that we already know have no use (as they cancel out to 0). – Ondřej Grover May 25 '15 at 13:54
  • 1) the result is real-valued, with a complex `dtype`, which is different; 2) the memory consumption is not twice, we only allocate once, for `res`, which is unavoidable, and then use `out` for `multiply()`; 3) notice that `all(res.imag==0)->True`, so that there is NO imaginary part at all; 4) you cannot think of complex to complex multiplication as 4 real-real multiplications and conclude there are time-consuming calculations. The code is faster then using `abs()` and this is what is asked. If you wonder why is that, this likely boils down to how CPUs implement complex numbers multiplication – gg349 May 25 '15 at 14:33
  • Even though it is real-valued (in theory), it still takes up memory for all the zero imaginary parts. I was talking about how much memory I need to get the final (real) result, assuming I don't want to overwrite arr. The minimum is 1.5 * arr size. Your suggestion is 2.0, because it takes up memory for the zero imaginary parts too. Relying on CPU optimizations is not very portable (although it would be hard to find a PC that would not have them these days). – Ondřej Grover May 25 '15 at 14:58
  • About the CPU 'optimizations', this is more about relying on `numpy` to have a decent performance over most platforms, and up to you to choose a platform with a reasonable floating point support. Anyway I see the point about the 2x memory requirement. – gg349 May 25 '15 at 15:10
  • In my time tests the `conjugate` method is half the speed of `arr.real**2+arr.imag**2`. It's doing 2x the number of multiplications, since the `multiply` does not short-circuit the terms that produce an imaginary value. And `abs**2` is about the same time as the conjugate multiply. – hpaulj May 25 '15 at 19:27
  • @hpaulj then I was wrong and the extra 2 unnecessary multiplications are time-consuming. On my machine this `conjugate` method is faster than `abs()**2` on all runs, but the gap is small (~5%). I'll however keep the answer and clarify, also because of these non-trivial comments. – gg349 May 26 '15 at 04:59

If your primary goal is to conserve memory, NumPy's ufuncs take an optional out parameter that lets you direct the output to an array of your choosing. It can be useful when you want to perform operations in place.

If you make this minor modification to your first method, then you can perform the operation on arr completely in place:

np.abs(arr, out=arr)
arr **= 2
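If arr itself must be preserved (as the question's edit clarifies), the same out trick works with a separate real buffer: np.abs on a complex128 input produces float64 output, so a float64 out array is accepted. A sketch of that variant, which matches the question's 1.5 * peak but makes the allocation explicit:

result = np.empty(arr.shape, dtype='float64')
np.abs(arr, out=result)  # the modulus goes straight into the real buffer
result **= 2             # squared in place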

One convoluted way that only uses a little extra memory could be to modify arr in place, compute the new array of real values and then restore arr.

This means storing information about the signs (unless you know that your complex numbers all have positive real and imaginary parts). A sign needs only a single bit in principle, but np.signbit returns one NumPy bool (one byte) per value, so the two sign arrays use 1/16 + 1/16 == 1/8 the memory of arr (in addition to the new array of floats you create).

>>> signs_real = np.signbit(arr.real) # store information about the signs
>>> signs_imag = np.signbit(arr.imag)
>>> arr.real **= 2 # square the real and imaginary values
>>> arr.imag **= 2
>>> result = arr.real + arr.imag
>>> arr.real **= 0.5 # positive square roots of real and imaginary values
>>> arr.imag **= 0.5
>>> arr.real[signs_real] *= -1 # restore the signs of the real and imaginary values
>>> arr.imag[signs_imag] *= -1

At the expense of storing the signbits, arr is restored (up to floating-point round-off in the square/square-root round trip) and result holds the values we want.

– Alex Riley
  • thank you, however, I don't want to overwrite arr, sorry for not making that clear. – Ondřej Grover May 25 '15 at 13:16
  • I see... I can't think of any way to do exactly what you want that (a) preserves `arr`, and (b) allocates only one new array of float values (of the same shape as `arr`). A custom ufunc might be needed (but this might affect portability). – Alex Riley May 25 '15 at 15:15
  • Thank you for your convoluted example. I might have to end up using numexpr. – Ondřej Grover May 25 '15 at 16:46
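For reference, a sketch of the numexpr route mentioned in the last comment (assuming numexpr is installed): evaluate processes the expression in cache-sized blocks, so no full-size temporaries are created, and it also accepts an out buffer.

import numexpr as ne
result = ne.evaluate("real(arr)**2 + imag(arr)**2")  # or pass out=result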

arr.real and arr.imag are only views into the complex array. So no additional memory is allocated.
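A quick check (a hypothetical snippet, not from the original answer) confirms this, and also shows why such views are non-contiguous:

import numpy as np
arr = np.empty((4,), dtype='complex128')
print(np.shares_memory(arr, arr.real))  # True: a view, no copy
print(arr.real.flags['C_CONTIGUOUS'])   # False: strided over the interleaved complex data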

– Daniel

If you want to avoid sqrt (which should be much heavier than a multiply), then avoid abs.

If you don't want to double the memory, then avoid real**2 + imag**2.

Then you might try this trick: reinterpret the complex array as interleaved pairs of floats and take batched dot products.

import numpy as np

N0 = 23
np0 = (np.random.randn(N0) + 1j*np.random.randn(N0)).astype(np.complex128)
ret_ = np.abs(np0)**2
tmp0 = np0.view(np.float64)  # reinterpret as interleaved (real, imag) float64 pairs
ret0 = np.matmul(tmp0.reshape(N0, 1, 2), tmp0.reshape(N0, 2, 1)).reshape(N0)  # batched dot: real**2 + imag**2
assert np.abs(ret_ - ret0).max() < 1e-7
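An equivalent spelling of the same trick with einsum (a hypothetical variant, not part of the original answer):

tmp1 = np0.view(np.float64).reshape(N0, 2)
ret1 = np.einsum('ij,ij->i', tmp1, tmp1)  # same batched real**2 + imag**2
assert np.abs(ret_ - ret1).max() < 1e-7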

Anyway, I prefer the numba solution.

– chao zhang