2

Recently, I came across a really interesting question:

Given a number N, in how many ways can it be written as a sum of distinct squared numbers?

For example, 25 can be written as:

25 = [5**2 , 3**2 + 4**2]

So the answer would be 2.

But for 31, there exist no solutions, so the answer would be 0.

At first, I thought this would be an easy problem to solve. I assumed that if I checked every combination of the squares of the integers up to the square root of N, I would be fine time-wise. (I found the combinations based on this)

def combs(a):
    if len(a) == 0:
        return [[]]
    cs = []
    for c in combs(a[1:]):
        cs += [c, c+[a[0]]]
    return cs

def main(n):
    k = 0
    l = combs(range(1, int(n**0.5)+1))
    for i in l:
        s = 0
        for j in i:
            s += j**2
        if s == n:
            print(i, n)
            k += 1
    print('\nanswer: ',k)

if __name__ == '__main__':
    main(n = 25)

However, if you run my solution, you'll see that for N > 400 the problem becomes practically impossible to solve in a short time, because the number of combinations doubles with every additional candidate square. So I was wondering whether any optimization exists for this problem.

  • 2
    You can use a recursive solution along with memoization. Subtract the square of a number, then call the function recursively on that number. – Barmar Mar 30 '22 at 21:40
  • 1
    Perhaps this may help: https://i.imgur.com/1yyxNk0.png. I'll write up a full solution later. EDIT: Forgot to include 5^2 in the image, but basically that's just another YIELD. – Mateen Ulhaq Mar 30 '22 at 22:15
  • @MateenUlhaq I understood your solution. Nevertheless, it would be kind of you if you add the full solution for everyone. Many thanks. – Amirhossein Rezaei Mar 31 '22 at 00:29

5 Answers

5

You can use a standard single-use "coin change" algorithm, with the values of the coins being squares.

def sum_distinct_squares(n):
    W = [1] + [0] * n
    for i in range(1, n+1):
        i2 = i * i
        if i2 > n:
            break
        for j in range(n, i2-1, -1):
            W[j] += W[j - i2]
    return W[n]

print(sum_distinct_squares(100000))

This runs in O(n * sqrt(n)) time, and solves the n=400 case in 0.016s on my machine, and n=100000 in 1.715s.
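The downward inner loop is what makes each square usable at most once: when `W[j]` is updated, `W[j - i2]` still holds the count from before the current square was considered. As a sanity check, the sketch below compares the same table-based count against a plain enumeration of subsets for small n. The names `count_dp` and `count_brute` are made up for this illustration, and `math.isqrt` stands in for the explicit `break`.

```python
from itertools import combinations
from math import isqrt


def count_dp(n):
    # ways[j] = number of sets of distinct squares (among those seen so far)
    # that sum to j; the empty set is the one way to reach 0.
    ways = [1] + [0] * n
    for i in range(1, isqrt(n) + 1):
        sq = i * i
        # Go downward so ways[j - sq] has not yet been updated for this square.
        for j in range(n, sq - 1, -1):
            ways[j] += ways[j - sq]
    return ways[n]


def count_brute(n):
    # Enumerate every non-empty subset of the squares up to n (small n only).
    squares = [i * i for i in range(1, isqrt(n) + 1)]
    return sum(
        1
        for size in range(1, len(squares) + 1)
        for subset in combinations(squares, size)
        if sum(subset) == n
    )
```

For example, `count_dp(25)` and `count_brute(25)` both give 2, matching the example in the question. The brute-force version is only there for cross-checking and becomes slow quickly as n grows.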

Paul Hankin
  • 1
    [A variation](https://tio.run/##bY4xDoMwDEX3nMIjFAagS4XUczAghKISioc4wTFDT58GSjv1L5ae/Wz7lyyOrjfPMc7sLFgtC6D1jgUwrCzqwM4b1uL429LTpNRkZgibHScMgvSQMaybZhMyylsFKR3coa8HKKCvBrgAHXTe1wASsKanyerycyhZRX2Ke7BJNiYLf6jrsWmHhK32WXqhPEmq@THERjamRGlQyjOSZH8frKs9eR7jGw), seems a bit faster. – Kelly Bundy Mar 31 '22 at 07:36
  • 1
    Even faster by going [from large squares to small](https://tio.run/##bY4xDoMwDEX3nMIjtFQCulSVeg4GFKGoCcVDnOCYoaenAdFO/YulZz/b8S1ToOst8rqOHDx4IxOgj4EFMM0saschOjYS@Nsy1ipl3Qhp8YPFJEhPGdK8GHapoPKuIKeDB/SNhjP0tYYT0E7HbQ0gARt6uWK/kpUK6gouzaFuwTb7mD38oa7H9q4z9iYW@YnqILmW@xA7WZgyJa1UZCQp/r7Y1FvKcl0/) (I guess the numbers stay smaller for longer). – Kelly Bundy Mar 31 '22 at 07:41
  • @KellyBundy I've added another optimization based on this answer. – Amirhossein Rezaei Mar 31 '22 at 10:22
1

The following MiniZinc model copes with n=400 (55 solutions) in less than a second:

int: n = 25;
set of int: Domain = 1..ceil(pow(n, 0.5));

%  the array of decision variable decides
%  which integers between 1 and n^0.5 are added as squares
array[Domain] of var bool: b;

constraint n == sum([b[i] * i * i | i in Domain]);

output ["\(n) = "] ++ [if fix(b[i]) then "+\(i)²" else "" endif | i in Domain];

To evaluate the solutions for n=400 by brute force:
Enumerate the 1,048,576 20-bit integers and record as solutions those that yield the desired sum. Each of the twenty bits decides whether the corresponding integer in 1..20 is squared and added. It does not take that long to loop through a million cases.
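For readers who don't use MiniZinc, the bit-enumeration idea can be sketched in Python roughly as follows. This is an illustration of the brute-force description above, not a translation of the model; the function name `count_by_bitmask` is invented here, and `isqrt` is used instead of the model's `ceil(pow(n, 0.5))`.

```python
from math import isqrt


def count_by_bitmask(n):
    # One bit per candidate integer 1..k, where k = floor(sqrt(n)).
    k = isqrt(n)
    squares = [i * i for i in range(1, k + 1)]
    count = 0
    for mask in range(1 << k):
        total = 0
        for bit in range(k):
            # Bit number `bit` set means: include (bit + 1) squared in the sum.
            if (mask >> bit) & 1:
                total += squares[bit]
        if total == n:
            count += 1
    return count
```

`count_by_bitmask(25)` returns 2. For n=400 this loops over all 2**20 masks, which is much slower in pure Python than in a constraint solver, and the work doubles each time `isqrt(n)` grows by one.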

Axel Kemper
  • This seems very interesting... yet, I'm not familiar with the language. I think a more algorithmic solution would be more appropriate as an acceptable answer. – Amirhossein Rezaei Mar 30 '22 at 22:22
1

Implementation:

from functools import cache
from math import sqrt

@cache
def _square_sums(n, max_i):
    if n == 0:
        return 1
    start = min(max_i, int(sqrt(n)))
    return sum(_square_sums(n - i**2, i - 1) for i in range(start, 0, -1))

def square_sums(n):
    return _square_sums(n, max_i=n)

Tests:

>>> square_sums(25)
2

>>> square_sums(55)
1

>>> square_sums(400)
55

>>> square_sums(10000)
3296089777

>>> square_sums(100000)
2759256389896728737285379

Performance characteristics:

Fast for n < 10000, but slow for much larger n.
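One way to see where the time goes is to count how many distinct `(n, max_i)` pairs end up in the cache. The sketch below re-declares the function so that it is self-contained; `cached_states` is a helper name invented here, and `isqrt` is swapped in for `int(sqrt(n))`.

```python
from functools import cache
from math import isqrt


@cache
def _square_sums(n, max_i):
    # Number of ways to write n as a sum of distinct squares i*i with i <= max_i.
    if n == 0:
        return 1
    start = min(max_i, isqrt(n))
    return sum(_square_sums(n - i * i, i - 1) for i in range(start, 0, -1))


def cached_states(n):
    # Start from an empty cache so the count reflects this call only.
    _square_sums.cache_clear()
    result = _square_sums(n, n)
    return result, _square_sums.cache_info().currsize
```

Calling `cached_states(n)` for a few increasing values of n shows how quickly the number of memoized states grows, which is a plausible explanation for the slowdown compared with the single-table approach.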

Mateen Ulhaq
  • 1
    `list(square_sums(55))` is `[]` but 1² + 2² + 3² + 4² + 5² = 55. Your stopping condition enforces ordering, but unfortunately excludes valid solutions. – Paul Hankin Mar 31 '22 at 07:00
  • 1
    The wrong assumption is that the largest square i² in a valid sum is always such that i >= sqrt(n)/2. But the sum of the first k squares grows like k³/6, so for large n there can be solutions with squares that are arbitrarily smaller than sqrt(n). – Paul Hankin Mar 31 '22 at 07:31
  • 1
    @PaulHankin Fixed by adding a [`max_i` parameter](https://stackoverflow.com/revisions/71685979/2). To speed things up, I [abandoned enumerating](https://stackoverflow.com/revisions/71685979/5) all the solutions. Though it's still quite a bit slower than your solution. – Mateen Ulhaq Mar 31 '22 at 08:28
1

Based on @PaulHankin's excellent answer, another optimization is also possible using the fast Numba JIT compiler:

import time
import numba
import matplotlib.pyplot as plt

@numba.jit
def sum_distinct_squares_jit(n):
    W = [1] + [0] * n
    for i in range(1, n+1):
        i2 = i * i
        if i2 > n:
            break
        for j in range(n, i2-1, -1):
            W[j] += W[j - i2]
    return W[n]

def sum_distinct_squares_non_jit(n):
    W = [1] + [0] * n
    for i in range(1, n+1):
        i2 = i * i
        if i2 > n:
            break
        for j in range(n, i2-1, -1):
            W[j] += W[j - i2]
    return W[n]

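def loop(N, jit=False):
    # Pick the implementation to time
    sds = sum_distinct_squares_jit if jit else sum_distinct_squares_non_jit

    times = []
    # time.clock() was removed in Python 3.8; perf_counter() replaces it
    t_i = time.perf_counter()
    for number in range(1, N):
        sds(number)
        # Cumulative time since the start of the loop
        times.append(time.perf_counter() - t_i)
    plt.plot(times)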
    plt.plot(times)

loop(N = 10_000, jit = True)
loop(N = 10_000, jit = False)
plt.legend(['JIT', 'None JIT'])
plt.show()

Running this in a for loop and plotting the time it takes for JIT and Non-JIT, we get:

[Figure: cumulative running time of the JIT and non-JIT versions, as plotted by the code above]

So, using JIT is significantly faster. However, this comes at a price: Numba does not support Python's arbitrary-precision integers, so for large N, where the counts no longer fit in a fixed-width integer, the Numba version is not valid.
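To estimate where a fixed-width integer stops being enough, one option is to build the table once with ordinary Python integers and look for the first count above a threshold. This is a rough sketch that assumes the counts are held in signed 64-bit integers; `count_table`, `first_exceeding` and `INT64_MAX` are names introduced here for illustration.

```python
INT64_MAX = 2**63 - 1  # assumed limit for a signed 64-bit integer


def count_table(limit):
    # ways[j] = number of ways to write j as a sum of distinct squares,
    # for every j in 0..limit, using Python's unbounded integers.
    ways = [1] + [0] * limit
    i = 1
    while i * i <= limit:
        sq = i * i
        for j in range(limit, sq - 1, -1):
            ways[j] += ways[j - sq]
        i += 1
    return ways


def first_exceeding(limit, threshold):
    # Smallest n <= limit whose count is above threshold, or None if there is none.
    for n, count in enumerate(count_table(limit)):
        if count > threshold:
            return n
    return None
```

Calling `first_exceeding(some_limit, INT64_MAX)` with a large enough `some_limit` gives the first N whose count no longer fits; in pure Python this takes a while for limits in the tens of thousands.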

  • 1
    Until what N *is* Numba valid? – Kelly Bundy Mar 31 '22 at 10:35
  • @KellyBundy I can't say with certainty at which N Numba breaks (since every number has its own count of solutions), but if you had a rough idea of the relationship between the order of N and the order of the solution, then, given that the maximum integer width in Numba is currently limited to 64 bits, you could find the upper limit of this approach. – Amirhossein Rezaei Mar 31 '22 at 10:50
  • 1
    Ok, then it [works up to N=54443](https://tio.run/##bZDNasMwEITveoo9Wq4LtltKCaSvkYMRQa3lZglaKav1oU/vSmr6B5nLwjc7I7HxQ06BHp4jb9vCwYO3cgL0MbAApguLqjhEx1YCf1t2npWa3QJp9ccZkyC9yTFdVssuNaR3CrIOsIdpMHAHU2@gBap0KTWABGzp3TX1lRzpoO/gfrhGi3DMecw5/EGHCcedydjb2ORPdFeSp75VTn/rlrJs4GUPY9s@Pf46RZGRpMFaaPQ/65WdPVfCTlamvEJGqa/EzQMMfZHW2/YJ). – Kelly Bundy Mar 31 '22 at 11:23
0

Only a finite set of numbers cannot be written as a sum of distinct squares (http://oeis.org/A001422).

So here is an O(1) solution:

if N not in [2, 3, 6, 7, 8, 11, 12, 15, 18, 19, 22, 23, 24, 27, 28, 31, 32, 33, 43, 44, 47, 48, 60, 67, 72, 76, 92, 96, 108, 112, 128]:
    print("N is a sum of distinct squares")
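The list can be cross-checked against a counting table for a bounded range. The sketch below only checks values up to a chosen limit and does not prove that the set is finite; `NOT_SUMS` and `numbers_without_representation` are names invented for this illustration.

```python
# The values listed above, taken from the answer's text.
NOT_SUMS = [2, 3, 6, 7, 8, 11, 12, 15, 18, 19, 22, 23, 24, 27, 28, 31, 32,
            33, 43, 44, 47, 48, 60, 67, 72, 76, 92, 96, 108, 112, 128]


def numbers_without_representation(limit):
    # ways[j] = number of ways to write j as a sum of distinct squares.
    ways = [1] + [0] * limit
    i = 1
    while i * i <= limit:
        sq = i * i
        for j in range(limit, sq - 1, -1):
            ways[j] += ways[j - sq]
        i += 1
    # Positive integers up to limit that have no representation at all.
    return [n for n in range(1, limit + 1) if ways[n] == 0]
```

Comparing `numbers_without_representation(1000)` with `NOT_SUMS` confirms the list within that range.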
Jean Valj
  • I wouldn't call that a "solution", as the question asks for the *number* of combinations, not just for whether there is one. – Kelly Bundy Apr 03 '22 at 10:29
  • I think [this](http://oeis.org/A033461) is the correct sequence. Maybe its formulas/codes could be used here, but I'm not familiar enough with those languages to decipher them :-) – Kelly Bundy Apr 03 '22 at 10:35
  • @KellyBundy True but the title is asking for less ;-) – Jean Valj Apr 03 '22 at 13:41