
I am trying to efficiently compute a summation of a summation in Python:

$$\sum_{x=0}^{n} \left( x^2 \sum_{y=0}^{x} y \right)$$

WolframAlpha is able to compute it for a high n value: sum of sum.

I have two approaches: a for loop method and an np.sum method. I thought the np.sum approach would be faster. However, they give the same result only up to a fairly large n; beyond that, the np.sum version overflows and gives the wrong result.

I am trying to find the fastest way to compute this sum.

import numpy as np
import time

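def summation(start, end, func):
    # Add func(i) for i = start..end inclusive ("total" avoids shadowing the built-in sum)
    total = 0
    for i in range(start, end+1):
        total += func(i)
    return total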

def x(y):
    return y

def x2(y):
    return y**2

def mysum(y):
    return x2(y)*summation(0, y, x)

n=100

# method #1
start=time.time()
summation(0,n,mysum)
print('Slow method:',time.time()-start)

# method #2
start=time.time()
w=np.arange(0,n+1)
(w**2*np.cumsum(w)).sum()
print('Fast method:',time.time()-start)
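To put a number on "a large n": assuming the NumPy version runs with a 64-bit integer dtype (NumPy's default integer type is platform-dependent, so this is an assumption), the following sketch finds the first n whose exact result no longer fits in int64. It uses a plain-Python running inner sum, so it cannot overflow itself; the name `first_overflowing_n` is hypothetical.

```python
# Largest value a signed 64-bit integer can hold.
INT64_MAX = 2**63 - 1


def first_overflowing_n(limit=INT64_MAX):
    """Return the smallest n whose exact double sum exceeds `limit`.

    Python ints are arbitrary precision, so the running totals are exact.
    """
    inner = 0  # running value of sum(range(x + 1))
    total = 0  # running value of the outer sum up to x
    x = 0
    while True:
        inner += x
        total += x * x * inner
        if total > limit:
            return x
        x += 1


print(first_overflowing_n())
```

At that point each individual term x**2 * cumsum is still far below the int64 limit, so in the NumPy version it is the final `.sum()` that wraps around first.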
wjandrea
Adam
  • The image is unreadable. It is also unclickable (for [the link](https://www.codecogs.com/eqnedit.php?latex=\sum_{x=0}^{100}\left(x^2\sum_{y=0}^x&space;y\right&space;))). Can you fix it? E.g., by providing an *additional* static image. – Peter Mortensen Nov 07 '21 at 12:59
  • @PeterMortensen [To me it's readable](https://i.stack.imgur.com/4EAeN.png) (though could be better). How does it look to you? – Kelly Bundy Nov 07 '21 at 15:00
  • @PeterMortensen I [enlarged](https://i.stack.imgur.com/Q2F0U.png) it and fixed the link. Is it readable for you now? – Kelly Bundy Nov 07 '21 at 15:13
  • Replace 100 by `n` on Wolfram Alpha. You're [done](https://www.wolframalpha.com/input/?i2d=true&i=Sum%5B%5C%2840%29Power%5Bx%2C2%5DSum%5By%2C%7By%2C0%2Cx%7D%5D%5C%2841%29%2C%7Bx%2C0%2Cn%7D%5D). – Eric Duminil Nov 07 '21 at 16:09
  • @EricDuminil Ha, nice. Would've saved me time. Although if I'm not miscounting, I still have one multiplication less. Do you know why those formulas are written the way they are? – Kelly Bundy Nov 07 '21 at 16:20
  • @KellyBundy: `1/120 n (n + 1) (n + 2) (12 n^2 + 9 n - 1)` helps to see the simple roots, even though the `(12 n^2 + 9 n - 1)` could be factorized further. `n (n (n ((n/10 + 3/8) n + 5/12) + 1/8) - 1/60)` has the same number of multiplications and additions as yours, I think. – Eric Duminil Nov 07 '21 at 16:36
  • @Peter It was [unreadable on dark mode](https://i.stack.imgur.com/7dW9D.png) so I made the background white to fix that. Is that what you mean? – wjandrea Nov 07 '21 at 18:25
  • @EricDuminil I think that's one multiplication less than mine, but the fractions of course introduce inaccuracies (unless you use real fractions instead of floats, but then you have more operations internally). A relatively small case is [n=2134](https://tio.run/##VY7BDsIgEETv/Yq9FTQNLLSmFz/GGBpNmi1BTDTGb8clVLQcltlHZhj/jJeF7OhDSgRHMGj7xj28O0feBB80sAOCPfSDXNWgvwpX1oEpQilAo5vgbvc5JxBT8RuCFGr2WTXWMIXs5aiMOr4OWjY@XCmKqX3lTu/2D9RuG1r/29BpXk5RFIcsjyl9AA), where with ordinary floats it computes 4433380147350155.0 instead of 4433380147350154. – Kelly Bundy Nov 07 '21 at 19:21
  • @KellyBundy Yes, your formula has the advantage of never computing any float or fraction if you use integer as an input. – Eric Duminil Nov 08 '21 at 07:57
  • @EricDuminil If n weren't an integer, I wouldn't even know how to interpret the question's `\sum_{x=0}^{n}` :-) – Kelly Bundy Nov 08 '21 at 13:43
  • @KellyBundy: Just like for https://en.wikipedia.org/wiki/Factorial and https://en.wikipedia.org/wiki/Gamma_function . As long as they agree on integers, it's nice to be able to use an extended function which also accepts reals or complex numbers. – Eric Duminil Nov 08 '21 at 14:14

5 Answers


Here's a very fast way:

result = ((((12 * n + 45) * n + 50) * n + 15) * n - 2) * n // 120
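A quick way to gain confidence in the formula is to compare it with brute force for many small n, and to check that the numerator is always a multiple of 120 so the floor division `//` is exact. This is only a sketch: `naive` mirrors the brute-force function used in the code further down, `closed_form` simply wraps the expression above, and the factored form n(n+1)(n+2)(12n^2 + 9n - 1)/120 is the one Eric Duminil quoted in the comments on the question.

```python
def naive(n):
    # Direct O(n^2) evaluation of the double sum.
    return sum(x**2 * sum(range(x + 1)) for x in range(n + 1))


def closed_form(n):
    # Horner form of (12n^5 + 45n^4 + 50n^3 + 15n^2 - 2n) / 120.
    return ((((12 * n + 45) * n + 50) * n + 15) * n - 2) * n // 120


for n in range(300):
    numerator = ((((12 * n + 45) * n + 50) * n + 15) * n - 2) * n
    # The numerator must be an exact multiple of 120 for // to be exact.
    assert numerator % 120 == 0
    assert closed_form(n) == naive(n)
    # Factored form from the comments on the question.
    assert n * (n + 1) * (n + 2) * (12 * n * n + 9 * n - 1) // 120 == closed_form(n)

print("closed form matches brute force for n < 300")
```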

How I got there:

  1. Rewrite the inner sum as the well-known x*(x+1)//2. So the whole thing becomes sum(x**2 * x*(x+1)//2 for x in range(n+1)).
  2. Rewrite to sum(x**4 + x**3 for x in range(n+1)) // 2.
  3. Look up formulas for sum(x**4) and sum(x**3).
  4. Simplify the resulting mess to (12*n**5 + 45*n**4 + 50*n**3 + 15*n**2 - 2*n) // 120.
  5. Horner it.
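Spelled out with the standard power-sum formulas (the ones looked up in step 3), steps 1 to 4 read as follows; the last line is exactly the polynomial from step 4.

```latex
\begin{aligned}
\sum_{x=0}^{n} x^2 \sum_{y=0}^{x} y
  &= \sum_{x=0}^{n} x^2 \cdot \frac{x(x+1)}{2}
   = \frac{1}{2}\left(\sum_{x=0}^{n} x^4 + \sum_{x=0}^{n} x^3\right) \\
\sum_{x=0}^{n} x^4 &= \frac{n(n+1)(2n+1)(3n^2+3n-1)}{30}
   = \frac{6n^5 + 15n^4 + 10n^3 - n}{30} \\
\sum_{x=0}^{n} x^3 &= \frac{n^2(n+1)^2}{4}
   = \frac{n^4 + 2n^3 + n^2}{4} \\
\frac{1}{2}\left(\frac{6n^5 + 15n^4 + 10n^3 - n}{30} + \frac{n^4 + 2n^3 + n^2}{4}\right)
  &= \frac{12n^5 + 45n^4 + 50n^3 + 15n^2 - 2n}{120}
\end{aligned}
```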

Another way to derive it, if after steps 1 and 2 you know it's a polynomial of degree 5:

  1. Compute six values with a naive implementation.
  2. Compute the polynomial from the six equations with six unknowns (the polynomial coefficients). I did it similarly to this, but my matrix A is left-right mirrored compared to that, and I called my y-vector b.

Code:

from fractions import Fraction
import math
from functools import reduce

def naive(n):
    return sum(x**2 * sum(range(x+1)) for x in range(n+1))

def lcm(ints):
    return reduce(lambda r, i: r * i // math.gcd(r, i), ints)

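def polynomial(xys):
    # Fit the polynomial through the (x, y) points exactly, using Fractions
    # and Gaussian elimination on the Vandermonde system A * coeffs = b.
    xs, ys = zip(*xys)
    n = len(xs)
    A = [[Fraction(x**i) for i in range(n)] for x in xs]
    b = list(ys)
    # Pass 1 clears everything below the diagonal. Mirroring A and b turns
    # the remaining upper part into a lower part, so pass 2 clears it too,
    # and the second mirror restores the original order with A diagonal.
    for _ in range(2):
        for i0 in range(n):
            for i in range(i0 + 1, n):
                f = A[i][i0] / A[i0][i0]
                for j in range(i0, n):
                    A[i][j] -= f * A[i0][j]
                b[i] -= f * b[i0]
        A = [row[::-1] for row in A[::-1]]
        b.reverse()
    coeffs = [b[i] / A[i][i] for i in range(n)]
    # Scale to integer coefficients over a common denominator.
    denominator = lcm(c.denominator for c in coeffs)
    coeffs = [int(c * denominator) for c in coeffs]
    # Build the Horner form, starting from the highest-degree coefficient.
    horner = str(coeffs[-1])
    for c in coeffs[-2::-1]:
        horner += ' * n'
        if c:
            horner = f"({horner} {'+' if c > 0 else '-'} {abs(c)})"
    return f'{horner} // {denominator}'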

print(polynomial((x, naive(x)) for x in range(6)))

Output (Try it online!):

((((12 * n + 45) * n + 50) * n + 15) * n - 2) * n // 120
Kelly Bundy
  • Thanks! While this isn't what I was looking for (the question I asked here was actually an extreme simplification of the real double series I am computing). This does solve the question I asked. I should have specified that I was looking for computational ways to improve the calculation. – Adam Nov 06 '21 at 18:23
  • @Adam I guess that explains why you used all those functions in the non-NumPy solution, that did look rather odd. Maybe a more general case would still allow similar optimizations, but it depends on how more general. Maybe ask another question with the general formula as you really have it? Like, with `f(x)` and `g(y)` instead of `x^2` and `y` or so, where `f` and `g` are unknown functions (though perhaps some properties are known and could be taken advantage of). – Kelly Bundy Nov 06 '21 at 18:37
  • @Adam Yeah I think in a case like this where you've simplified your actual problem, it's really important to explain in the question that the given formula is just an example but your goal is really to figure out how to compute a sum quickly, and not to get an actual answer for that specific formula. Otherwise, you run the risk of getting a solution like this one, which is unquestionably the best way to solve the problem you asked but doesn't help you at all with the problem you really have. – David Z Nov 07 '21 at 03:49

(fastest methods, 3 and 4, are at the end)

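In the fast NumPy method you need to specify dtype=object (written np.object in older NumPy; that alias was deprecated in NumPy 1.20 and removed in 1.24) so that NumPy does not convert Python ints to its own fixed-width dtypes (np.int64 or others). It will then give you correct results (I checked up to N=100000).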

# method #2
start=time.time()
w=np.arange(0, n+1, dtype=object)
result2 = (w**2*np.cumsum(w)).sum()
print('Fast method:', time.time()-start)
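A small sketch of the difference, assuming NumPy 1.24 or later (so plain `object` is used instead of the removed `np.object`) and using an explicit `np.int64` dtype so the result does not depend on the platform's default integer type. With n = 10000 the exact result is above 2**63 - 1, so the int64 sum silently wraps around, while the object array holds Python ints and stays exact.

```python
import numpy as np

n = 10_000

# Fixed-width int64: the element-wise products still fit, but the final
# .sum() exceeds 2**63 - 1 and wraps around without raising an error.
w64 = np.arange(0, n + 1, dtype=np.int64)
wrapped = int((w64 * w64 * np.cumsum(w64)).sum())

# Object dtype: every element is a Python int, so nothing can overflow
# (NumPy calls Python's * and + element by element instead).
wobj = np.arange(0, n + 1, dtype=object)
exact = (wobj * wobj * np.cumsum(wobj)).sum()

print(wrapped)
print(exact)
print(exact - wrapped == 2**64)  # the int64 result is off by exactly 2**64
```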

Your fast solution is significantly faster than the slow one, and not only for large N: already at N=100 it is about 8 times faster:

start=time.time()
for i in range(100):
    result1 = summation(0, n, mysum)
print('Slow method:', time.time()-start)

# method #2
start=time.time()
for i in range(100):
    w=np.arange(0, n+1, dtype=object)
    result2 = (w**2*np.cumsum(w)).sum()
print('Fast method:', time.time()-start)
Slow method: 0.06906533241271973
Fast method: 0.008007287979125977

EDIT: An even faster method (by KellyBundy, the Pumpkin) uses pure Python. It turns out NumPy has no advantage here, because it has no vectorized code for object arrays.

# method #3
import itertools
start=time.time()
for i in range(100):
    result3 = sum(x*x * ysum for x, ysum in enumerate(itertools.accumulate(range(n+1))))
print('Faster, pure python:', (time.time()-start))
Faster, pure python: 0.0009944438934326172

EDIT2: Forss noticed that the NumPy fast method can be optimized by using x*x instead of x**2. For N > 200 it is faster than the pure Python method, and for N < 200 it is slower (the exact boundary may depend on the machine; on mine it was 200, so it's best to check it yourself):

# method #4
start=time.time()
for i in range(100):
    w = np.arange(0, n+1, dtype=object)
    result4 = (w*w*np.cumsum(w)).sum()
print('Fast method x*x:', time.time()-start)
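Timing with time.time() around a hand-written loop is fairly coarse. A sketch using the standard timeit module instead, comparing method 3 and method 4 at a few sizes; `method3`/`method4` are just hypothetical labels for the snippets above, and `object` replaces the removed np.object alias. No timings are shown here because, as noted above, the crossover point depends on the machine.

```python
import itertools
import timeit

import numpy as np


def method3(n):
    # Pure Python: running inner sum via itertools.accumulate.
    return sum(x * x * ysum for x, ysum in enumerate(itertools.accumulate(range(n + 1))))


def method4(n):
    # NumPy object array (Python ints, so no overflow), using x*x instead of x**2.
    w = np.arange(0, n + 1, dtype=object)
    return (w * w * np.cumsum(w)).sum()


for n in (100, 1_000, 10_000):
    assert method3(n) == method4(n)
    for name, func in (("method 3", method3), ("method 4", method4)):
        calls = 20
        # Best of 5 repeats, divided by the number of calls per repeat.
        best = min(timeit.repeat(lambda: func(n), number=calls, repeat=5)) / calls
        print(f"n={n:>6}  {name}: {best * 1e6:8.1f} µs per call")
```

Taking the minimum of several repeats reduces noise from other processes; it is still worth running this on your own machine before trusting either method for a given n.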
dankal444
  • For N=100 000, Slow method: 588.831 s Fast method: 0.0500 s – dankal444 Nov 06 '21 at 16:38
  • why np.object instead of np.int64 then? – diggusbickus Nov 06 '21 at 16:40
  • @diggusbickus because `np.int64` has only 64 bits to store integers, Python `int` can be as big as your RAM will allow it. By using "generic" `np.object` you assure that numpy does not convert `int` to `np.int64`. – dankal444 Nov 06 '21 at 16:51
  • @dankal444 And *real* fast method 0.000001 s :-P – Kelly Bundy Nov 06 '21 at 17:21
  • I'd try the equivalent non-NumPy version as well, you'll likely find it gets *faster* than the NumPy version. For example `result1 = sum(x*x * ysum for x, ysum in enumerate(itertools.accumulate(range(n+1))))` or `ysum = 0; result1 = sum(x*x * (ysum := ysum + x) for x in range(n+1))` – Kelly Bundy Nov 06 '21 at 19:27
  • Because at the moment I think it might mislead casual readers into thinking "Ah ok, NumPy makes this fast". Which is wrong, since the NumPy one is faster not because it uses NumPy but because it uses a different algorithm, which the non-NumPy solution could easily use as well, and *should* use for meaningful comparison. – Kelly Bundy Nov 07 '21 at 15:39
  • @KellyBundy I thought you would add this to your answer; I can edit mine if you prefer it this way – dankal444 Nov 07 '21 at 15:48
  • I don't think it fits into mine (since I'm only talking about my approach and would like to keep it that way), but it would make yours better. – Kelly Bundy Nov 07 '21 at 15:50
  • @KellyBundy edited then, thank you, learned a lot by your answers and comments – dankal444 Nov 07 '21 at 16:06
  • `numpy.object` is not a special NumPy thing. It's just `object`. It's only in the `numpy` namespace for backward compatibility. – user2357112 Nov 07 '21 at 21:17
  • In general, NumPy cannot optimize any operations on arrays of object dtype. Unless you have a very specific, not-performance-related reason to use an object array, it's usually better to use either a list or an array with a native dtype. – user2357112 Nov 07 '21 at 21:19
  • Running the correction with `np.object` gives me this `DeprecationWarning`: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. Deprecated in NumPy 1.20; for more details and guidance: [https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations](https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations) – ramzeek Nov 08 '21 at 04:37
  • The pure python version is cheating a bit in the comparison by using `x*x` instead of `x**2` like the other methods. Changing to `x*x` for the numpy solution it is the fastest method for larger n (on my computer). – Forss Nov 12 '21 at 12:17
  • @Forss Hmm, right. I used `x*x` out of habit, though it's a habit partially because of speed. And now that you pointed it out, I was disappointed that NumPy doesn't just figure out once that it's one multiplication and apply that epiphany to the whole array. But then I remembered that we're using `object` and I guess NumPy doesn't make assumptions/type analysis there. Without the `dtype`, it does do `x*x` and `x**2` equally fast. With both solutions using `x*x`, NumPy is a bit faster for me at n=1000, about equally fast at n=10000, and a bit slower at n=100000. – Kelly Bundy Nov 12 '21 at 13:44
  • @Forss great you noticed it! edited answer and added your fix as method 4. In my tests for any N>200 numpy x*x is faster than pure python – dankal444 Nov 12 '21 at 19:08

Comparing Python with WolframAlpha like that is unfair, since Wolfram will simplify the equation before computing.

Fortunately, the Python ecosystem knows no limits, so you can use SymPy:

from sympy import summation
from sympy import symbols

n, x, y = symbols("n,x,y")
eq = summation(x ** 2 * summation(y, (y, 0, x)), (x, 0, n))
# Key the substitution by the Symbol n itself
eq.evalf(subs={n: 1000})

It will compute the expected result almost instantly: 100375416791650. This is because SymPy simplifies the equation for you, just like Wolfram does. See the value of eq:

[Image: the simplified closed-form expression of eq as displayed by SymPy]

@Kelly Bundy's answer is awesome, but if you are like me and use a calculator to compute 2 + 2, then you will love SymPy ❤. As you can see, it gets you to the same results with just 3 lines of code and is a solution that would also work for other more complex cases.

Peque

All the answers use math to simplify the sum, or implement the loop in Python trying to be CPU-optimal, but they are not memory-optimal.

Here is a naive implementation, without any math simplification, that is memory efficient:

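def function5():
    # Running-sum loop: O(n) time, O(1) extra memory (n is the global from the question).
    # Note: float accumulators are only exact while values stay below 2**53.
    inner_sum = 0.0
    result = 0.0

    for x in range(0, n + 1):
        inner_sum += x
        result += x ** 2 * inner_sum

    return result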

It is quite slow compared with the solutions by dankal444:

method 2   | 31 µs ± 2.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
method 3   | 116 µs ± 538 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
method 4   | 91 µs ± 356 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
function 5 | 217 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

By the way, if you JIT-compile the function with Numba (there may be better options):

from numba import jit
function5 = jit(nopython=True)(function5)

you get

59.8 ns ± 0.209 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
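One caveat with function5 as written: float accumulators only represent integers exactly up to 2**53, and the result passes that somewhere between n = 2400 and n = 2500. A sketch of the same loop with Python int accumulators (with n passed as a parameter; `loop_int` and `loop_float` are hypothetical names) shows the difference at n = 10000. This applies to plain CPython; under Numba's nopython mode, integer accumulators would be fixed-width int64 and would overflow instead, much like the NumPy int64 case.

```python
def loop_int(n):
    # Same running-sum loop as function5, but with exact Python ints.
    inner_sum = 0
    result = 0
    for x in range(n + 1):
        inner_sum += x
        result += x * x * inner_sum
    return result


def loop_float(n):
    # function5's approach: float accumulators round once values exceed 2**53.
    inner_sum = 0.0
    result = 0.0
    for x in range(n + 1):
        inner_sum += x
        result += x ** 2 * inner_sum
    return result


n = 10_000
exact = loop_int(n)
approx = loop_float(n)
print(exact, int(approx), exact - int(approx))
```

The relative error stays tiny, but the float result is no longer the exact integer.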
Ruggero Turra

In a comment, you mention that it's really $f(x)$ and $g(y)$ instead of $x^2$ and $y$. If you only need an approximation to that sum, you can pretend the sums are midpoint Riemann sums, so that your sum is approximated by the double integral $\int_{-0.5}^{n+0.5} f(x) \int_{-0.5}^{x+0.5} g(y)\,dy\,dx$.

With your original $f(x) = x^2$ and $g(y) = y$, this simplifies to $n^5/10 + 3n^4/8 + n^3/2 + 5n^2/16 + 3n/32 + 1/160$, which differs from the correct result by $n^3/12 + 3n^2/16 + 53n/480 + 1/160$.

Based on this, I suspect that $(\text{actual} - \text{integral})/\text{actual}$ would be $\max(f'', g'') \cdot O(n^{-2})$, but I wasn't able to prove it.
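For the original $f(x) = x^2$, $g(y) = y$, the polynomial and the stated error can be checked exactly with `fractions.Fraction`. This is only a sketch of that check: `integral_approx` hard-codes the polynomial from this answer rather than integrating numerically, and `exact` uses the closed form from Kelly Bundy's answer.

```python
from fractions import Fraction


def exact(n):
    # Closed form of sum_{x=0}^{n} x^2 * sum_{y=0}^{x} y.
    return ((((12 * n + 45) * n + 50) * n + 15) * n - 2) * n // 120


def integral_approx(n):
    # Midpoint-style double integral for f(x) = x^2, g(y) = y, both from -1/2.
    n = Fraction(n)
    return n**5 / 10 + 3 * n**4 / 8 + n**3 / 2 + 5 * n**2 / 16 + 3 * n / 32 + Fraction(1, 160)


def claimed_error(n):
    # integral minus exact sum, as stated above.
    n = Fraction(n)
    return n**3 / 12 + 3 * n**2 / 16 + 53 * n / 480 + Fraction(1, 160)


for n in (10, 100, 1000):
    err = integral_approx(n) - exact(n)
    assert err == claimed_error(n)
    # Relative error shrinks roughly like 1/n^2 for this f and g.
    print(n, float(err / exact(n)))
```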

Teepeemm