
Is there an integer square root somewhere in python, or in standard libraries? I want it to be exact (i.e. return an integer), and raise an exception if the input isn't a perfect square.

I tried using this code:

import math

def isqrt(n):
    i = int(math.sqrt(n) + 0.5)
    if i**2 == n:
        return i
    raise ValueError('input was not a perfect square')

But it's ugly and I don't really trust it for large integers. I could iterate through the squares and give up if I've exceeded the value, but I assume it would be kinda slow to do something like that. Also, surely this is already implemented somewhere?


See also: Check if a number is a perfect square.

Karl Knechtel
wim
  • It's not a requirement that comes up often, so there's no built-in. There's nothing wrong with the solution you have, but I'd make one stylistic change - reverse the condition of the `if` so the `return` comes last. – Mark Ransom Mar 13 '13 at 16:22
  • Can't it overflow/screw up for large inputs because of working with floats? – wim Mar 13 '13 at 16:24
  • @wim: it can and will. – DSM Mar 13 '13 at 16:24
  • http://code.activestate.com/recipes/577821-integer-square-root-function/ – NPE Mar 13 '13 at 16:24
  • It will overflow when `n` becomes too large to fit in a float without truncation, which is at 2**53. Even so it might still work because of the rounding you do to the result. Are you really going to be working with numbers that large? – Mark Ransom Mar 13 '13 at 16:32
  • Yes, I'm going to be working with numbers MUCH larger than 2**53. – wim Mar 13 '13 at 16:40
  • Precision is the real problem. Since Python supports extended precision integers, doing it in floating point will be a real handicap, since it will limit the magnitude. – Tom Karzes May 09 '16 at 03:20
  • Python's default `math` library has an integer square root function [`math.isqrt(n)`](https://docs.python.org/3/library/math.html#math.isqrt) – iacob Apr 22 '21 at 10:07

14 Answers


Note: There is now math.isqrt in stdlib, available since Python 3.8.

Newton's method works perfectly well on integers:

def isqrt(n):
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

This returns the largest integer x for which x * x does not exceed n. If you want to check if the result is exactly the square root, simply perform the multiplication to check if n is a perfect square.
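Since the question also wants an exception for non-squares, here is a minimal sketch that wraps the Newton iteration above with the suggested multiplication check. The negative-input guard and the name `exact_isqrt` are additions for illustration, not part of the original answer.

```python
def isqrt(n):
    """Largest integer x with x*x <= n (integer Newton iteration, as above)."""
    if n < 0:
        raise ValueError('square root not defined for negative numbers')
    x = n
    y = (x + 1) // 2
    while y < x:  # the estimates shrink until they stop decreasing
        x = y
        y = (x + n // x) // 2
    return x


def exact_isqrt(n):
    """Hypothetical helper: the exact root of a perfect square, else ValueError."""
    r = isqrt(n)
    if r * r != n:
        raise ValueError('input was not a perfect square')
    return r
```

For example, `exact_isqrt(10**40)` returns `10**20`, while `exact_isqrt(10**40 + 1)` raises `ValueError`.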

I discuss this algorithm, and three other algorithms for calculating square roots, at my blog.

wim
user448810
  • You can get a much better initial approximation using `y = 1 << (n.bit_length()>>1)` (thx to mathmandan). – greggo Jul 06 '15 at 16:32
  • @greggo This is a good idea, but it isn't compatible with the rest of user448810's algorithm. It causes the function to return 8 for the square root of 100. – Chris Culter Sep 17 '15 at 06:02
  • Yes, the algorithm as stated requires the x,y estimates to start above the true square root and work down, so my formula needs to be adjusted; off the top of my head, `y = 2 << (n.bit_length()>>1)` should work, though there may be a way to slice it a bit finer. – greggo Sep 18 '15 at 14:37
  • @greggo That nearly works, but fails for n = 2, 3, 4, or 8. `y = (2 ** ((n.bit_length()+1) // 2)) - 1` will work for all non-negative integers, including 0. – clwainwright Dec 03 '15 at 20:28
  • Is there a quick modification to the algorithm to acquire the square root that is the smallest square but bigger or equal to the input number? That is `isqrt'(14) == 4` not `3`? – CMCDragonkai Apr 17 '18 at 02:40
  • @CMCDragonkai: You want `1 + isqrt(n-1)` (assuming that `n >= 1`). – Mark Dickinson Sep 08 '18 at 19:13
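Putting two suggestions from this comment thread together, here is a minimal sketch that uses clwainwright's starting estimate for the Newton loop and Mark Dickinson's formula for the ceiling root. The function names `isqrt_newton` and `isqrt_ceil` are hypothetical.

```python
def isqrt_newton(n):
    """Floor square root, starting from clwainwright's estimate."""
    if n < 0:
        raise ValueError('square root not defined for negative numbers')
    x = n
    # 2**ceil(bits/2) - 1 is never below floor(sqrt(n)), so the iteration
    # still approaches the root from above, as the algorithm requires.
    y = (1 << ((n.bit_length() + 1) >> 1)) - 1
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x


def isqrt_ceil(n):
    """Smallest r with r*r >= n, via 1 + isqrt(n - 1) for n >= 1."""
    if n == 0:
        return 0
    return 1 + isqrt_newton(n - 1)
```

With this starting point the loop begins within a factor of about two of the root instead of at `n` itself, so it skips the long run of halving steps. `isqrt_ceil(14)` gives 4, as asked in the comment.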

Update: Python 3.8 has a math.isqrt function in the standard library!

I benchmarked every (correct) function here on both small (0…2**22) and large (2**50001) inputs. The clear winners in both cases are gmpy2.isqrt suggested by mathmandan in first place, followed by Python 3.8’s math.isqrt in second, followed by the ActiveState recipe linked by NPE in third. The ActiveState recipe has a bunch of divisions that can be replaced by shifts, which makes it a bit faster (but still behind the native functions):

def isqrt(n):
    if n > 0:
        x = 1 << (n.bit_length() + 1 >> 1)
        while True:
            y = (x + n // x) >> 1
            if y >= x:
                return x
            x = y
    elif n == 0:
        return 0
    else:
        raise ValueError("square root not defined for negative numbers")

Benchmark results:

(* Since gmpy2.isqrt returns a gmpy2.mpz object, which behaves mostly but not exactly like an int, you may need to convert it back to an int for some uses.)

Nico Schlömer
Anders Kaseorg
  • I haven't checked them all out, but `gmpy2.isqrt` is a C extension module, so it's not really doing it in Python. – martineau Dec 31 '18 at 08:20
  • @martineau Which is why I spent effort optimizing a pure Python answer too. Some readers will want pure Python, some will want the fastest thing regardless of technology, some will want something small and clear to copy and paste, some will want a solution they can just import; who knows. I’m just presenting the options in a way that allows them to be compared. – Anders Kaseorg Dec 31 '18 at 08:28
  • I'm the `gmpy2` maintainer and I have a couple of comments on related functions in `gmpy2`. `gmpy2.is_square()` is usually the fastest method to determine if a number is a perfect square. It performs some very quick tests that can identify most non-perfect squares quickly and only calculates the square root if needed. `gmpy2.isqrt_rem()` will return the integer square root and the remainder. – casevh Jan 02 '19 at 05:47
  • You should credit Fredrik Johansson instead of ActiveState. The code there is just a Python implementation of his answer here: https://stackoverflow.com/a/1624602/4311651 – Wood Jun 04 '19 at 09:00
  • See the [math theory](https://math.stackexchange.com/a/34236/432081). – CopyPasteIt Jul 27 '21 at 13:04
  • Keeping `gmpy2` aside (since the system I am working with won't have gmpy2), if `int(math.sqrt())` is faster than `math.isqrt()`, why bother about `isqrt`? And I need it for big numbers (~300 digits). – ishandutta2007 May 19 '22 at 01:33
  • @ishandutta2007 `int(math.sqrt(n))` will give the **wrong answer** for `n` greater than about `2**52` due to floating-point imprecision. For example, `int(math.sqrt(2**52 + 2**27)) == 67108865`, while `math.isqrt(2**52 + 2**27) == 67108864`. – Anders Kaseorg May 19 '22 at 01:59
  • @AndersKaseorg Well not just wrong answer, I am getting error infact `OverflowError: int too large to convert to float` my code is `m = int(math.sqrt(n))` . I am on `Python 3.8.0` – ishandutta2007 May 19 '22 at 02:40
  • Thanks for your very interesting answer! Cool that you provided detailed timings of dozens of methods. Please take a look at [my answer](https://stackoverflow.com/a/73843172/941531), which I just posted. In it I wrote C++ code for several popular methods to compare timings in pure C++, and created a Cython wrapper over the C++ code so that you can use all these C++ methods from Python. I was very inspired by your answer and included your optimized code, rewritten in pure C++, as one of my C++ functions. You can edit my Cython code if you want to add more C++ methods, or tell me what to add. – Arty Sep 25 '22 at 08:56

Sorry for the very late response; I just stumbled onto this page. In case anyone visits this page in the future, the python module gmpy2 is designed to work with very large inputs, and includes among other things an integer square root function.

Example:

>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)

Granted, everything will have the "mpz" tag, but mpzs are compatible with ints:

>>> gmpy2.mpz(3)*4
mpz(12)

>>> int(gmpy2.mpz(12))
12

See my other answer for a discussion of this method's performance relative to some other answers to this question.

Download: https://code.google.com/p/gmpy/

Community
mathmandan

Long-hand square root algorithm

It turns out that there is an algorithm for computing square roots that you can compute by hand, something like long-division. Each iteration of the algorithm produces exactly one digit of the resulting square root while consuming two digits of the number whose square root you seek. While the "long hand" version of the algorithm is specified in decimal, it works in any base, with binary being simplest to implement and perhaps the fastest to execute (depending on the underlying bignum representation).

Because this algorithm operates on numbers digit-by-digit, it produces exact results for arbitrarily large perfect squares, and for non-perfect-squares, can produce as many digits of precision (to the right of the decimal place) as desired.

There are two nice writeups on the "Dr. Math" site that explain the algorithm:

And here's an implementation in Python:

def exact_sqrt(x):
    """Calculate the square root of an arbitrarily large integer. 
 
    The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
    a is the largest integer such that a**2 <= x, and r is the "remainder".  If
    x is a perfect square, then r will be zero.
 
    The algorithm used is the "long-hand square root" algorithm, as described at
    http://mathforum.org/library/drmath/view/52656.html
 
    Tobin Fricke 2014-04-23
    Max Planck Institute for Gravitational Physics
    Hannover, Germany
    """
    
    N = 0   # Problem so far
    a = 0   # Solution so far
    
    # We'll process the number two bits at a time, starting at the MSB
    L = x.bit_length()
    L += (L % 2)          # Round up to the next even number
    
    for i in range(L, -1, -1):
        
        # Get the next group of two bits
        n = (x >> (2*i)) & 0b11
        
        # Check whether we can reduce the remainder
        if ((N - a*a) << 2) + n >= (a<<2) + 1:
            b = 1
        else:
            b = 0
        
        a = (a << 1) | b   # Concatenate the next bit of the solution
        N = (N << 2) | n   # Concatenate the next bit of the problem
    
    return (a, N-a*a)

You could easily modify this function to conduct additional iterations to calculate the fractional part of the square root. I was most interested in computing roots of large perfect squares.
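The fractional-digit extension can also be had without modifying the loop at all: scaling the input appends pairs of zero digits, exactly like bringing down "00" groups by hand. The sketch below uses `math.isqrt` (Python 3.8+) instead of `exact_sqrt` for brevity, and `sqrt_digits` is a hypothetical name. With `exact_sqrt`, `exact_sqrt(x << (2 * f))[0]` gives `f` fractional bits the same way.

```python
import math


def sqrt_digits(x, d):
    """Hypothetical helper: floor(sqrt(x) * 10**d) for a non-negative int x."""
    # Multiplying x by 10**(2*d) appends d pairs of zero digits, just like
    # bringing down extra "00" groups in the long-hand method, so the
    # integer root gains d more digits after the decimal point.
    return math.isqrt(x * 10 ** (2 * d))


print(sqrt_digits(2, 10))  # 14142135623, i.e. sqrt(2) = 1.4142135623...
```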

I'm not sure how this compares to the "integer Newton's method" algorithm. I suspect that Newton's method is faster, since it can in principle generate multiple bits of the solution in one iteration, while the "long hand" algorithm generates exactly one bit of the solution per iteration.
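For comparison with castle-bravo's comment about halving the iterations, here is a minimal Python 3 sketch of the same two-bits-at-a-time idea. It loops only over the ceil(bits/2) bit pairs that can be non-zero and carries the remainder forward instead of recomputing `N - a*a` on every step, which removes the per-step multiplication. The name `exact_sqrt_v2` is hypothetical; this is not the code from castle-bravo's gist.

```python
def exact_sqrt_v2(x):
    """Return (a, r) with a*a + r == x and a == floor(sqrt(x))."""
    if x < 0:
        raise ValueError('square root not defined for negative numbers')
    a = 0  # root bits found so far
    r = 0  # remainder: (prefix of x processed so far) - a*a
    for i in range((x.bit_length() + 1) // 2 - 1, -1, -1):
        # Bring down the next pair of bits of x.
        r = (r << 2) | ((x >> (2 * i)) & 0b11)
        # Appending a 1 bit to a costs (2a+1)**2 - (2a)**2 = 4a + 1.
        t = (a << 2) | 1
        a <<= 1
        if r >= t:
            r -= t
            a |= 1
    return a, r
```

It returns the same `(a, remainder)` pair as `exact_sqrt`, so it can be dropped in as a replacement.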

Source repo: https://gist.github.com/tobin/11233492

Glorfindel
nibot
  • I rewrote your solution for Python3 and made it ~2.4 times faster! The single biggest optimization was re-writing the range function to do half as many iterations, but I massaged every line to squeeze out whatever performance I could. Please check out my version [here](https://gist.github.com/castle-bravo/e841684d6bad8e0598e31862a7afcfc7) and have my gratitude for providing a great starting point. – castle-bravo Sep 26 '16 at 00:39

Here's a very straightforward implementation:

def i_sqrt(n):
    i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
    m = 1 << i    # m = 2^i
    #
    # Fact: (2^(i + 1))^2 > n, so m has at least as many bits 
    # as the floor of the square root of n.
    #
    # Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
    # >= 2^(ceil(log_2(n)) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
    #
    while m*m > n:
        m >>= 1
        i -= 1
    for k in range(i-1, -1, -1):
        x = m | (1 << k)
        if x*x <= n:
            m = x
    return m

This is just a binary search. Initialize the value m to be the largest power of 2 that does not exceed the square root, then check whether each smaller bit can be set while keeping the result no larger than the square root. (Check the bits one at a time, in descending order.)

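For reasonably large values of n (say, around 10**6000, or around 20000 bits), this seems to be faster than @user448810's implementation, slower than @Nibot's, and much slower than the gmpy2 built-in method.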

All of these approaches succeed on inputs of this size, but on my machine, this function takes around 1.5 seconds, while @Nibot's takes about 0.9 seconds, @user448810's takes around 19 seconds, and the gmpy2 built-in method takes less than a millisecond(!). Example:

>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True

This function can be generalized easily, though it's not quite as nice because I don't have as precise an initial guess for m:

def i_root(num, root, report_exactness = True):
    i = num.bit_length() // root
    m = 1 << i
    while m ** root < num:
        m <<= 1
        i += 1
    while m ** root > num:
        m >>= 1
        i -= 1
    for k in range(i-1, -1, -1):
        x = m | (1 << k)
        if x ** root <= num:
            m = x
    if report_exactness:
        return m, m ** root == num
    return m

However, note that gmpy2 also has an iroot function.

In fact this method could be adapted and applied to any (nonnegative, increasing) function f to determine an "integer inverse of f". However, to choose an efficient initial value of m you'd still want to know something about f.
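A minimal sketch of that generalization, assuming `f` is non-decreasing on the non-negative integers, `f(0) <= y`, and `f` eventually exceeds `y`. The helper `int_inverse` is hypothetical. It uses the crudest possible initial bound (doubling until `f` exceeds `y`), which is exactly the cost of not knowing anything about `f`. It then sets bits from the top down, the same way `i_sqrt` does.

```python
def int_inverse(f, y):
    """Hypothetical helper: largest integer m >= 0 with f(m) <= y."""
    # Crude initial bound: find k with f(2**k) > y, so the answer is < 2**k.
    k = 0
    while f(1 << k) <= y:
        k += 1
    # Set bits from the top down, keeping each one only if f stays <= y.
    m = 0
    for b in range(k - 1, -1, -1):
        candidate = m | (1 << b)
        if f(candidate) <= y:
            m = candidate
    return m
```

For example, `int_inverse(lambda m: m * m, 99)` gives 9 and `int_inverse(lambda m: m ** 3, 10**30)` gives `10**10`.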

Edit: Thanks to @Greggo for pointing out that the i_sqrt function can be rewritten to avoid using any multiplications. This yields an impressive performance boost!

def improved_i_sqrt(n):
    assert n >= 0
    if n == 0:
        return 0
    i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
    m = 1 << i    # m = 2^i
    #
    # Fact: (2^(i + 1))^2 > n, so m has at least as many bits
    # as the floor of the square root of n.
    #
    # Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
    # >= 2^(ceil(log_2(n)) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
    #
    while (m << i) > n: # (m<<i) = m*(2^i) = m*m
        m >>= 1
        i -= 1
    d = n - (m << i) # d = n-m^2
    for k in range(i-1, -1, -1):
        j = 1 << k
        new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
        if new_diff >= 0:
            d = new_diff
            m |= j
    return m

Note that by construction, the kth bit of m << 1 is not set, so bitwise-or may be used to implement the addition of (m<<1) + (1<<k). Ultimately I have (2*m*(2**k) + 2**(2*k)) written as (((m<<1) | (1<<k)) << k), so it's three shifts and one bitwise-or (followed by a subtraction to get new_diff). Maybe there is still a more efficient way to get this? Regardless, it's far better than multiplying m*m! Compare with above:

>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True
Community
mathmandan
  • There is a way to eliminate all multiplies in the square root op, basically you keep track of the difference between n and m**2, and adjust that difference downwards whenever m is increased. If you look up the source code for a software version of sqrt e.g. http://www.netlib.org/fdlibm/e_sqrt.c you will find this method in use (and that one has a good explanation in the comments). – greggo Jul 06 '15 at 16:30
  • @greggo Excellent! I have posted an improved (multiplication-free) version of the integer square root function--let me know if you have further suggestions. Thanks so much for your help! – mathmandan Jul 06 '15 at 23:40

One option would be to use the decimal module and do it in sufficiently precise decimal floating point:

import decimal

def isqrt(n):
    nd = decimal.Decimal(n)
    with decimal.localcontext() as ctx:
        ctx.prec = max(1, n.bit_length())  # prec must be >= 1, even for n == 0
        i = int(nd.sqrt())
    if i**2 != n:
        raise ValueError('input was not a perfect square')
    return i

which I think should work:

>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
  File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
    isqrt(11**1000+1)
  File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
    raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square
DSM
  • You can improve on this slightly, by letting `decimal` do the work of raising exceptions. Just add `ctx.traps = [decimal.Rounded, decimal.Inexact]` to set the trap conditions and it will raise exceptions if either occurs; no need to test and `raise` yourself. You can also improve performance a great deal by changing `ctx.prec = n.bit_length()` to something like `ctx.prec = int(math.ceil(math.log10(n))) + 1` which dramatically reduces precision, but *should* still always provide *enough* precision. Of course, this assumes you don't have `math.isqrt` and can't update to Python 3.8+ to get it. – ShadowRanger Dec 27 '19 at 21:11
  • @ShadowRanger As noted in some other comments the math module is not an option here because it cannot handle arbitrary large integers – miracle173 Jun 07 '21 at 20:41
  • @miracle173: `math.isqrt` can. The other comments (which predate the *existence* of `math.isqrt`) are referring to `math.sqrt` (or equivalently, just doing `** 0.5`), both of which produce `float` outputs. `math.isqrt` is pure integer to integer, and handles arbitrarily large integers just fine. – ShadowRanger Jun 08 '21 at 00:44
  • @ShadowRanger Yes, you are right. math.isqrt is able to handle arbitrarily large integers – miracle173 Jun 08 '21 at 03:03
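ShadowRanger's first comment suggests letting decimal raise the exception itself. Here is a minimal sketch of that idea, under a few assumptions of mine. It traps only `decimal.Inexact` (set by item assignment on `ctx.traps`). It keeps DSM's `n.bit_length()` precision but clamps it to at least 1 so that `n == 0` works. It also converts the trap into `ValueError` to match the question. The name `isqrt_decimal` is hypothetical.

```python
import decimal


def isqrt_decimal(n):
    """Exact integer square root via decimal, raising ValueError for non-squares."""
    if n < 0:
        raise ValueError('square root not defined for negative numbers')
    with decimal.localcontext() as ctx:
        # bit_length() is at least the number of decimal digits of n, so the
        # root of any perfect square fits exactly; max(1, ...) keeps prec valid.
        ctx.prec = max(1, n.bit_length())
        # Turn the Inexact signal into an exception inside this context only.
        ctx.traps[decimal.Inexact] = True
        try:
            root = decimal.Decimal(n).sqrt()
        except decimal.Inexact:
            raise ValueError('input was not a perfect square') from None
    return int(root)
```

The root of a non-square is irrational, so `Inexact` is signalled whatever the precision. The root of a D-digit perfect square has about D/2 digits and is returned exactly.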

Python's default math library has an integer square root function:

math.isqrt(n)

Return the integer square root of the nonnegative integer n. This is the floor of the exact square root of n, or equivalently the greatest integer a such that a² ≤ n.
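A short check of that documented property, contrasted with the float route from the question. The helper `check_isqrt` is hypothetical. The input `2**52 + 2**27` is the one from Anders Kaseorg's comment on another answer, and there the float square root rounds up to the next integer.

```python
import math


def check_isqrt(n):
    """Hypothetical helper: math.isqrt(n), after checking a*a <= n < (a+1)**2."""
    a = math.isqrt(n)
    assert a * a <= n < (a + 1) * (a + 1)
    return a


n = 2**52 + 2**27
print(check_isqrt(n))     # 67108864, the exact floor of the root
print(int(math.sqrt(n)))  # 67108865: the float square root rounds up

big = 10**400 + 12345     # too large to convert to a float at all
print(check_isqrt(big) == 10**200)  # True
```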

iacob

Seems like you could check like this:

import math

if int(math.sqrt(n))**2 == n:
    print(n, 'is a perfect square')

Update:

As you pointed out the above fails for large values of n. For those the following looks promising, which is an adaptation of the example C code, by Martin Guy @ UKC, June 1985, for the relatively simple looking binary numeral digit-by-digit calculation method mentioned in the Wikipedia article Methods of computing square roots:

def isqrt(n):
    res = 0
    # smallest power of 4 >= the argument, using integer ops only (no float log)
    bit = 1 << ((n - 1).bit_length() + 1 & -2) if n else 0
    while bit:
        if n >= res + bit:
            n -= res + bit
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2
    return res

if __name__ == '__main__':
    from math import sqrt  # for comparison purposes

    for i in list(range(17)) + [2**53, (10**100+1)**2]:
        is_perfect_sq = isqrt(i)**2 == i
        print('{:21,d}:  math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
            i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else ''))

Output:

                    0:  math.sqrt=           0, isqrt=         0 (perfect square)
                    1:  math.sqrt=           1, isqrt=         1 (perfect square)
                    2:  math.sqrt=    1.414214, isqrt=         1
                    3:  math.sqrt=    1.732051, isqrt=         1
                    4:  math.sqrt=           2, isqrt=         2 (perfect square)
                    5:  math.sqrt=    2.236068, isqrt=         2
                    6:  math.sqrt=     2.44949, isqrt=         2
                    7:  math.sqrt=    2.645751, isqrt=         2
                    8:  math.sqrt=    2.828427, isqrt=         2
                    9:  math.sqrt=           3, isqrt=         3 (perfect square)
                   10:  math.sqrt=    3.162278, isqrt=         3
                   11:  math.sqrt=    3.316625, isqrt=         3
                   12:  math.sqrt=    3.464102, isqrt=         3
                   13:  math.sqrt=    3.605551, isqrt=         3
                   14:  math.sqrt=    3.741657, isqrt=         3
                   15:  math.sqrt=    3.872983, isqrt=         3
                   16:  math.sqrt=           4, isqrt=         4 (perfect square)
9,007,199,254,740,992:  math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001:  math.sqrt=      1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)
martineau
  • @wim: True...I believe the last update I made to my answer fixes that shortcoming and is therefore a very usable solution. – martineau Mar 14 '13 at 15:41
  • The `4 ** int(ceil(log(n, 4)))` construction still relies on floating-point math, and thus may fail for large inputs. Use `4 ** ((n - 1).bit_length() + 1 >> 1)` instead, or better, `1 << ((n - 1).bit_length() + 1 & -2)`. – Anders Kaseorg Dec 31 '18 at 03:15
  • @AndersKaseorg: I'll look into it. FWIW, `bit_length()` doesn't exist in Python 2, which is the version that was used to develop the code currently in the answer. – martineau Dec 31 '18 at 08:28
  • It exists in 2.7. (Not saying I don’t believe that you weren’t using 2.7 back then, but I hope I’m only being slightly optimistic to say that nobody’s developing for pre-2.7 now.) – Anders Kaseorg Dec 31 '18 at 08:31
  • @AndersKaseorg: Hard to say exactly what I was using/thinking almost 6 years ago, but I often didn't use features that only existed in 2.7 to make the code work on more Pythons. That said, I've dealt with lack of `bit_length()` before and think I have the pure Python equivalent laying around somewhere—but it would surely be slower. – martineau Dec 31 '18 at 08:46
  • `(n - 1).bit_length()` ↦ `len(bin(n - 1)) - 2` was the typical workaround from that era. – Anders Kaseorg Dec 31 '18 at 08:59

The script below extracts integer square roots. It uses no divisions, only bitshifts, so it is quite fast. It uses Newton's method on the inverse square root, a technique made famous by Quake III Arena as mentioned in the Wikipedia article, Fast inverse square root.

The strategy of the algorithm to compute s = sqrt(Y) is as follows.

  1. Reduce the argument Y to y in the range [1/4, 1), i.e., y = Y/B, with 1/4 <= y < 1, where B is an even power of 2, so B = 2**(2*k) for some integer k. We want to find X, where x = X/B, and x = 1 / sqrt(y).
  2. Determine a first approximation to X using a quadratic minimax polynomial.
  3. Refine X using Newton's method.
  4. Calculate s = X*Y/(2**(3*k)).

We don't actually create fractions or perform any divisions. All the arithmetic is done with integers, and we use bit shifting to divide by various powers of B.

Range reduction lets us find a good initial approximation to feed to Newton's method. Here's a version of the 2nd degree minimax polynomial approximation to the inverse square root in the interval [1/4, 1):

[Figure: minimax polynomial for 1/sqrt(x) on [1/4, 1)]

(Sorry, I've reversed the meaning of x & y here, to conform to the usual conventions). The maximum error of this approximation is around 0.0355 ~= 1/28. Here's a graph showing the error:

[Figure: error graph of the minimax polynomial]

Using this poly, our initial x starts with at least 4 or 5 bits of precision. Each round of Newton's method doubles the precision, so it doesn't take many rounds to get thousands of bits, if we want them.
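A quick numerical check of the error claim. It reads the polynomial off the `int_sqrt` code below (`x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8`, i.e. 1/sqrt(t) ≈ (463t² - 896t + 698)/256 for t = y/b). This sketch uses ordinary floating point and samples the interval rather than proving the bound. The sampled maximum should come out a little above 0.035, which is just under 5 bits.

```python
import math


def minimax_inv_sqrt(t):
    """The quadratic used by int_sqrt, in real arithmetic: ~1/sqrt(t) on [1/4, 1)."""
    return (463 * t * t - 896 * t + 698) / 256


# Sample [1/4, 1) densely and record the worst absolute error.
N = 100_000
max_err = max(abs(minimax_inv_sqrt(t) - 1 / math.sqrt(t))
              for t in (0.25 + 0.75 * i / N for i in range(N)))
bits = -math.log2(max_err)  # bits of precision in the starting guess
print(f"max error {max_err:.4f}, about {bits:.1f} bits")
```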


""" Integer square root

    Uses no divisions, only shifts
    "Quake" style algorithm,
    i.e., Newton's method for 1 / sqrt(y)
    Uses a quadratic minimax polynomial for the first approximation

    Written by PM 2Ring 2022.01.23
"""

def int_sqrt(y):
    if y < 0:
        raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
    if y < 2:
        return y

    # print("\n*", y, "*")
    # Range reduction.
    # Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
    j = y.bit_length()
    # Round k*2 up to the next even number
    k2 = j + (j & 1)
    # k and some useful multiples
    k = k2 >> 1
    k3 = k2 + k
    k6 = k3 << 1
    kd = k6 + 1
    # b cubed
    b3 = 1 << k6

    # Minimax approximation: x/b ~= 1 / sqrt(y/b)
    x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
    # print("   ", x, h)

    # Newton's method for 1 / sqrt(y/b)
    epsilon = 1 << k
    for i in range(1, 99):
        dx = x * (b3 - y * x * x) >> kd
        x += dx
        # print(f" {i}: {x} {dx}")
        if abs(dx) <= epsilon:
            break

    # s == sqrt(y)
    s = x * y >> k3
    # Adjust if too low
    ss = s + 1
    return ss if ss * ss <= y else s

def test(lo, hi, step=1):
    for y in range(lo, hi, step):
        s = int_sqrt(y)
        ss = s + 1
        s2, ss2 = s * s, ss * ss
        assert s2 <= y < ss2, (y, s2, ss2)
    print("ok")

test(0, 100000, 1)

This code is certainly slower than math.isqrt and decimal.Decimal.sqrt. Its purpose is simply to illustrate the algorithm. It would be interesting to see how fast it would be if it were implemented in C...


Here's a live version, running on the SageMathCell server. Set hi <= 0 to calculate and display the results for a single value set in lo. You can put expressions in the input boxes, e.g. set hi to 0 and lo to 2 * 10**100 to get sqrt(2) * 10**50.

PM 2Ring

Inspired by all the answers, I decided to implement several of the best methods from them in pure C++, which is usually much faster than pure Python.

To glue C++ and Python together I used Cython. It lets you build a Python module from C++ code and then call the C++ functions directly from Python.

In addition to the Python-adapted code, I also provide a pure C++ version with tests.

Here are timings from pure C++ tests:

Test           'GMP', bits     64, time  0.000001 sec
Test 'AndersKaseorg', bits     64, time  0.000003 sec
Test    'Babylonian', bits     64, time  0.000006 sec
Test  'ChordTangent', bits     64, time  0.000018 sec

Test           'GMP', bits  50000, time  0.000118 sec
Test 'AndersKaseorg', bits  50000, time  0.002777 sec
Test    'Babylonian', bits  50000, time  0.003062 sec
Test  'ChordTangent', bits  50000, time  0.009120 sec

The same C++ functions, wrapped as a Python module, have these timings:

Bits 50000
         math.isqrt:   2.819 ms
        gmpy2.isqrt:   0.166 ms
          ISqrt_GMP:   0.252 ms
ISqrt_AndersKaseorg:   3.338 ms
   ISqrt_Babylonian:   3.756 ms
 ISqrt_ChordTangent:  10.564 ms

My Cython/C++ setup also works as a small framework for anyone who wants to write their own C++ method and test it directly from Python.

The timings above cover the following methods:

  1. math.isqrt, the standard library implementation.

  2. gmpy2.isqrt, the GMPY2 library's implementation.

  3. ISqrt_GMP, the same as GMPY2 but called through my Cython module, which uses the C++ GMP library (<gmpxx.h>) directly.

  4. ISqrt_AndersKaseorg, code taken from the answer by @AndersKaseorg.

  5. ISqrt_Babylonian, the so-called Babylonian method from the Wikipedia article, in my own implementation.

  6. ISqrt_ChordTangent, my own method, which I call Chord-Tangent because it uses a chord and a tangent line to iteratively shrink the search interval. It is described in moderate detail in another article of mine. A nice property is that it finds not only square roots but also K-th roots for any K. I drew a small picture showing the details of this algorithm.

To compile the C++/Cython code you need the GMP library. Under Linux it is easy to install with sudo apt install libgmp-dev.

Under Windows, the easiest route is VCPKG, a very good package manager similar to APT on Linux. VCPKG builds all packages from source using Visual Studio (don't forget to install the Community edition of Visual Studio). After installing VCPKG you can install GMP with vcpkg install gmp. You may also install MPIR, an alternative fork of GMP, with vcpkg install mpir.

After GMP is installed under Windows, edit my Python code and replace the include-directory and library-file paths. At the end of the installation VCPKG should show you the path to a ZIP file with the GMP library, which contains the .lib and .h files.

You may notice that the Python code also contains a handy cython_compile() function that I use to compile any C++ code into a Python module. It makes it easy to plug arbitrary C++ code into Python and can be reused many times.

If you have any questions or suggestions, or something doesn't work on your PC, please write in comments.

Below I show the Python code first, then the C++ code. See the Try it online! link above the C++ code to run it online on Godbolt's servers. Both snippets are meant to be runnable from scratch as they are (apart from the Windows GMP paths mentioned above).

def cython_compile(srcs):
    import json, hashlib, os, glob, importlib, sys, shutil, tempfile
    srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
    pdir = 'cyimp'
    
    if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
        class ChDir:
            def __init__(self, newd):
                self.newd = newd
            def __enter__(self):
                self.curd = os.getcwd()
                os.chdir(self.newd)
                return self
            def __exit__(self, ext, exv, tb):
                os.chdir(self.curd)

        os.makedirs(pdir, exist_ok = True)
        with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
            os.makedirs(pdir, exist_ok = True)
                
            for k, v in srcs.items():
                with open(f'cys{srch}_{k}', 'wb') as f:
                    f.write(v.replace('{srch}', srch).encode('utf-8'))

            import numpy as np
            from setuptools import setup, Extension
            from Cython.Build import cythonize

            sys.argv += ['build_ext', '--inplace']
            setup(
                ext_modules = cythonize(
                    Extension(
                        f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
                        depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
                        extra_compile_args = ['/O2', '/std:c++latest',
                            '/ID:/dev/_3party/vcpkg_bin/gmp/include/',
                        ],
                    ),
                    compiler_directives = {'language_level': 3, 'embedsignature': True},
                    annotate = True,
                ),
                include_dirs = [np.get_include()],
            )
            del sys.argv[-2:]
            for f in glob.glob(f'{pdir}/cy{srch}*'):
                shutil.copy(f, f'./../')

    print('Cython module:', f'cy{srch}')
    return importlib.import_module(f'{pdir}.cy{srch}')

def cython_import():
    srcs = {
        'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>

#include <gmpxx.h>

#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")

#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }

using u32 = uint32_t;
using u64 = uint64_t;

template <typename T>
size_t BitLen(T n) {
    if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
        return mpz_sizeinbase(n.get_mpz_t(), 2);
    else {
        size_t cnt = 0;
        while (n >= (1ULL << 32)) {
            cnt += 32;
            n >>= 32;
        }
        while (n >= (1 << 8)) {
            cnt += 8;
            n >>= 8;
        }
        while (n) {
            ++cnt;
            n >>= 1;
        }
        return cnt;
    }
}

template <typename T>
T ISqrt_Babylonian(T const & y) {
    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
    if (y <= 1)
        return y;
    T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
    while (true) {
        size_t constexpr loops = 3;
        for (size_t i = 0; i < loops; ++i) {
            if (i + 1 >= loops)
                a = x;
            b = y;
            b /= x;
            x += b;
            x >>= 1;
        }
        if (b < a)
            std::swap(a, b);
        if (b - a > limit)
            continue;
        ++b;
        for (size_t i = 0; a <= b; ++a, ++i)
            if (a * a > y) {
                if (i == 0)
                    break;
                else
                    return a - 1;
            }
        ASSERT(false);
    }
}

template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
    // https://stackoverflow.com/a/53983683/941531
    if (n > 0) {
        T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
        while (true) {
            y = (x + n / x) >> 1;
            if (y >= x)
                return x;
            x = y;
        }
    } else if (n == 0)
        return 0;
    else
        ASSERT_MSG(false, "square root not defined for negative numbers");
}

template <typename T>
T ISqrt_GMP(T const & y) {
    // https://gmplib.org/manual/Integer-Roots
    mpz_class r, n;
    bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
    if constexpr(is_mpz)
        n = y;
    else {
        static_assert(sizeof(T) <= 8);
        n = u32(y >> 32);
        n <<= 32;
        n |= u32(y);
    }
    mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
    if constexpr(is_mpz)
        return r;
    else
        return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}

template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
    // https://i.stack.imgur.com/et9O0.jpg
    if (n <= 1)
        return n;
    auto KthPow = [&](auto const & x){
        T y = x * x;
        for (size_t i = 2; i < k; ++i)
            y *= x;
        return y;
    };
    auto KthPowDer = [&](auto const & x){
        T y = x * u32(k);
        for (size_t i = 1; i + 1 < k; ++i)
            y *= x;
        return y;
    };
    size_t root_bit_len = (BitLen(n) + k - 1) / k;
    T   hi = T(1) << root_bit_len,
        x_begin = hi >> 1, x_end = hi,
        y_begin = KthPow(x_begin), y_end = KthPow(x_end),
        x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
    for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
        if (x_end <= x_begin + 2)
            break;
        if constexpr(0) { // Do Binary Search step if needed
            x_mid = (x_begin + x_end) >> 1;
            y_mid = KthPow(x_mid);
            if (y_mid > n) {
                x_end = x_mid; y_end = y_mid;
            } else {
                x_begin = x_mid; y_begin = y_mid;
            }
        }
        // (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
        x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
        y_n = KthPow(x_n);
        tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
        chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
        //ASSERT(chord_x <= tangent_x);
        x_begin = chord_x; x_end = tangent_x;
        y_begin = KthPow(x_begin); y_end = KthPow(x_end);
        //ASSERT(y_begin <= n);
        //ASSERT(y_end > n);
    }
    for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
        if (x_begin * x_begin > n) {
            if (i == 0)
                break;
            else
                return x_begin - 1;
        }
    ASSERT(false);
    return 0;
}

mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
    mpz_class r;
    mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
    return r;
}

void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
    uint64_t cnt_before = *cnt;
    size_t cnt_res = 0;
    mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
    ASSERT(cnt_res <= cnt_before);
    std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
    *cnt = cnt_res;
}

void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}

void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}

void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}

void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
        """,
        'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION

import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *

cdef extern from "cys{srch}_lib.h" nogil:
    void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);

@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
    def ToLimbs():
        # Little-endian 64-bit limbs of the Python int n
        return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
    def FromLimbs(x):
        # Reassemble little-endian 64-bit limbs into a Python int
        return int.from_bytes(x.tobytes(), 'little')
    n = ToLimbs()
    cdef uint64_t[:] cn = n
    cdef uint64_t ccnt = len(n)
    cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
    with nogil:
        (ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
            <uint64_t *>&cn[0], <uint64_t *>&ccnt
        )
    return FromLimbs(n[:ccnt])
        """,
    }
    return cython_compile(srcs)

def main():
    import math, gmpy2, timeit, random
    mod = cython_import()
    fs = [
        ('math.isqrt', math.isqrt),
        ('gmpy2.isqrt', gmpy2.isqrt),
        ('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
        ('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
        ('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
        ('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
    ]
    times = [0] * len(fs)
    ntests = 1 << 6
    bits = 50000
    for i in range(ntests):
        n = random.randrange(1 << (bits - 1), 1 << bits)
        ref = None
        for j, (fn, f) in enumerate(fs):
            timeit_cnt = 3
            tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
            times[j] += tim
            x = f(n)
            if j == 0:
                ref = x
            else:
                assert x == ref, (fn, ref, x)
    print('Bits', bits)
    print('\n'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))

if __name__ == '__main__':
    main()

and C++:

Try it online!

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>

#include <gmpxx.h>

#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }

using u32 = uint32_t;
using u64 = uint64_t;

template <typename T>
size_t BitLen(T n) {
    if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
        return mpz_sizeinbase(n.get_mpz_t(), 2);
    else {
        size_t cnt = 0;
        while (n >= (1ULL << 32)) {
            cnt += 32;
            n >>= 32;
        }
        while (n >= (1 << 8)) {
            cnt += 8;
            n >>= 8;
        }
        while (n) {
            ++cnt;
            n >>= 1;
        }
        return cnt;
    }
}

template <typename T>
T ISqrt_Babylonian(T const & y) {
    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
    if (y <= 1)
        return y;
    T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
    while (true) {
        size_t constexpr loops = 3;
        for (size_t i = 0; i < loops; ++i) {
            if (i + 1 >= loops)
                a = x;
            b = y;
            b /= x;
            x += b;
            x >>= 1;
        }
        if (b < a)
            std::swap(a, b);
        if (b - a > limit)
            continue;
        ++b;
        for (size_t i = 0; a <= b; ++a, ++i)
            if (a * a > y) {
                if (i == 0)
                    break;
                else
                    return a - 1;
            }
        ASSERT(false);
    }
}

template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
    // https://stackoverflow.com/a/53983683/941531
    if (n > 0) {
        T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
        while (true) {
            y = (x + n / x) >> 1;
            if (y >= x)
                return x;
            x = y;
        }
    } else if (n == 0)
        return 0;
    else
        ASSERT_MSG(false, "square root not defined for negative numbers");
}

template <typename T>
T ISqrt_GMP(T const & y) {
    // https://gmplib.org/manual/Integer-Roots
    mpz_class r, n;
    bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
    if constexpr(is_mpz)
        n = y;
    else {
        static_assert(sizeof(T) <= 8);
        n = u32(y >> 32);
        n <<= 32;
        n |= u32(y);
    }
    mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
    if constexpr(is_mpz)
        return r;
    else
        return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}

template <typename T>
std::string IntToStr(T n) {
    if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
        return n.get_str();
    else {
        std::ostringstream ss;
        ss << n;
        return ss.str();
    }
}

template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
    // https://i.stack.imgur.com/et9O0.jpg
    if (n <= 1)
        return n;
    auto KthPow = [&](auto const & x){
        T y = x * x;
        for (size_t i = 2; i < k; ++i)
            y *= x;
        return y;
    };
    auto KthPowDer = [&](auto const & x){
        T y = x * u32(k);
        for (size_t i = 1; i + 1 < k; ++i)
            y *= x;
        return y;
    };
    size_t root_bit_len = (BitLen(n) + k - 1) / k;
    T   hi = T(1) << root_bit_len,
        x_begin = hi >> 1, x_end = hi,
        y_begin = KthPow(x_begin), y_end = KthPow(x_end),
        x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
    for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
        //std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
        if (x_end <= x_begin + 2)
            break;
        if constexpr(0) { // Do Binary Search step if needed
            x_mid = (x_begin + x_end) >> 1;
            y_mid = KthPow(x_mid);
            if (y_mid > n) {
                x_end = x_mid; y_end = y_mid;
            } else {
                x_begin = x_mid; y_begin = y_mid;
            }
        }
        // (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
        x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
        y_n = KthPow(x_n);
        tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
        chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
        //ASSERT(chord_x <= tangent_x);
        x_begin = chord_x; x_end = tangent_x;
        y_begin = KthPow(x_begin); y_end = KthPow(x_end);
        //ASSERT(y_begin <= n);
        //ASSERT(y_end > n);
    }
    for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
        if (x_begin * x_begin > n) {
            if (i == 0)
                break;
            else
                return x_begin - 1;
        }
    ASSERT(false);
    return 0;
}

mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
    mpz_class r;
    mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
    return r;
}

void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
    uint64_t cnt_before = *cnt;
    size_t cnt_res = 0;
    mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
    ASSERT(cnt_res <= cnt_before);
    std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
    *cnt = cnt_res;
}

void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}

void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}

void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}

void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
    ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}

// Testing

#include <chrono>
#include <random>
#include <vector>
#include <iomanip>

inline double Time() {
    static auto const gtb = std::chrono::high_resolution_clock::now();
    return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
        .count();
}

template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
    std::mt19937_64 rng{123};
    std::vector<T> nums;
    for (size_t i = 0; i < ntests; ++i) {
        T n = 0;
        for (size_t j = 0; j < bits; j += 32) {
            size_t const cbits = std::min<size_t>(32, bits - j);
            n <<= cbits;
            n ^= u32(rng()) >> (32 - cbits);
        }
        nums.push_back(n);
    }
    auto tim = Time();
    for (auto & n: nums)
        n = f(n);
    tim = Time() - tim;
    std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
        << ", bits " << std::setw(6) << bits << ", time "
        << std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
    return nums;
}

void Test() {
    auto f = [](auto ty, size_t bits, size_t ntests){
        using T = std::decay_t<decltype(ty)>;
        auto tim = Time();
        auto a = Test0<T>("GMP",           bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
        auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
        ASSERT(b == a);
        auto c = Test0<T>("Babylonian",    bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
        ASSERT(c == a);
        auto d = Test0<T>("ChordTangent",  bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
        ASSERT(d == a);
        std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
    };
    for (auto p: std::vector<std::pair<int, int>>{{15, 1 << 19}, {30, 1 << 19}})
        f(u64(), p.first, p.second);
    for (auto p: std::vector<std::pair<int, int>>{{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
        f(mpz_class(), p.first, p.second);
}

int main() {
    try {
        Test();
        return 0;
    } catch (std::exception const & ex) {
        std::cout << "Exception: " << ex.what() << std::endl;
        return -1;
    }
}
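For readers who only want the core algorithm without Cython or GMP, here is a minimal pure-Python sketch of the same Newton iteration used by `ISqrt_AndersKaseorg` above. The name `isqrt_newton` is hypothetical, and the comparison with `math.isqrt` assumes Python 3.8+.

```python
import math

def isqrt_newton(n):
    """Pure-Python port of the ISqrt_AndersKaseorg template: returns floor(sqrt(n))."""
    if n < 0:
        raise ValueError('square root not defined for negative numbers')
    if n == 0:
        return 0
    # Start from a power of two that is >= sqrt(n), so the iterates decrease monotonically.
    x = 1 << ((n.bit_length() + 1) >> 1)
    while True:
        y = (x + n // x) >> 1  # one integer Newton step
        if y >= x:             # stopped decreasing: x == floor(sqrt(n))
            return x
        x = y

# Agrees with the standard library on small and very large inputs.
print(all(isqrt_newton(k) == math.isqrt(k) for k in range(10000)))  # True
print(isqrt_newton((10**100 + 1)**2) == 10**100 + 1)  # True
```

The `y >= x` stopping rule works because the iteration starts above the root and strictly decreases until it reaches `floor(sqrt(n))`.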
Arty
0

Your function fails for large inputs:

In [26]: isqrt((10**100+1)**2)

ValueError: input was not a perfect square

There is a recipe on the ActiveState site which should hopefully be more reliable since it uses integer maths only. It is based on an earlier StackOverflow question: Writing your own square root function
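The recipe itself is not reproduced here. On Python 3.8+ the standard library already provides `math.isqrt`, which computes `floor(sqrt(n))` with integer arithmetic only, so a minimal sketch of the behaviour the question asks for (exact result, `ValueError` for non-squares) could look like this; `exact_isqrt` is a hypothetical name.

```python
import math

def exact_isqrt(n):
    """Return the exact integer square root of n; raise ValueError unless n is a perfect square."""
    if n < 0:
        raise ValueError('square root not defined for negative numbers')
    r = math.isqrt(n)  # floor(sqrt(n)), computed without floats (Python 3.8+)
    if r * r != n:
        raise ValueError('input was not a perfect square')
    return r

# The input that breaks the float-based version is handled correctly.
print(exact_isqrt((10**100 + 1)**2) == 10**100 + 1)  # True
```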

Community
NPE
-3

Real numbers cannot always be represented exactly as floating-point values. Instead of an exact comparison, you can test whether the result is within a small tolerance epsilon, chosen within the accuracy of Python's floats.

def isqrt(n):
    epsilon = .00000000001
    i = int(n**.5 + 0.5)
    if abs(i**2 - n) < epsilon:
        return i
    raise ValueError('input was not a perfect square')
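Why a tolerance cannot rescue this for large inputs: Python floats are IEEE 754 doubles with a 53-bit significand, so past `2**53` the gap between neighbouring floats exceeds 1, and the float root can be wrong by far more than any epsilon. A small demonstration (variable names are illustrative only):

```python
# Python floats are IEEE 754 doubles: 53 significant bits.
big = 2**53
print(float(big) == float(big + 1))  # True: big + 1 rounds to the same double

# A float square root therefore loses the low digits of a large root.
root = 10**100 + 1
guess = int((root * root) ** 0.5 + 0.5)
print(guess == root)  # False
# The error in guess**2 is astronomically larger than any epsilon.
print(abs(guess * guess - root * root) < 1e-11)  # False
```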
Eugene Yarmash
Octipi
  • This too seems to fail for larger values of n. Newton's method looks promising or the decimal.Decimal solution. – Octipi Mar 13 '13 at 17:14
-4

I have compared the different methods given here with a loop:

for i in range(1000000):  # 700 msec
    r = int(123456781234567**0.5 + 0.5)
    if r**2 == 123456781234567:
        rr = r
    else:
        rr = -1

I found that this one is the fastest, and it needs no math import. Very large numbers might fail, but look at this:

15241576832799734552675677489**0.5 = 123456781234567.0
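That example works because its root, 123456781234567, is below `2**53` and so fits exactly in a float. A sketch contrasting the float check with an exact one (using `math.isqrt`, Python 3.8+; both function names are hypothetical) on a perfect square whose root lies just above `2**60`:

```python
import math

def isqrt_float(n):
    """Float-based check from the loop above: the root if n is a perfect square, else -1."""
    r = int(n**0.5 + 0.5)
    return r if r * r == n else -1

def isqrt_exact(n):
    """Same contract, but using the integer-only math.isqrt (Python 3.8+)."""
    r = math.isqrt(n)
    return r if r * r == n else -1

# Root below 2**53: the float version gets it right.
print(isqrt_float(15241576832799734552675677489))  # 123456781234567

# Root above 2**53: float(n) rounds to 2**120, so n**0.5 gives 2**60 and the square is missed.
big = (2**60 + 1)**2
print(isqrt_float(big), isqrt_exact(big))  # -1 1152921504606846977
```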
ted
  • 1
    please look at this 1000000000000000000000053646891100000000000000011928482406771 (=nextprime(10^29+5342347)*nextprime(10^31+2232788)) from [A ONE LINE FACTORING ALGORITHM](http://wrap.warwick.ac.uk/54707/1/WRAP_Hart_S1446788712000146a.pdf) – miracle173 Jun 07 '21 at 20:48
-5

Try this condition (no additional computation):

import math

def isqrt(n):
    i = math.sqrt(n)
    if i != int(i):
        raise ValueError('input was not a perfect square')
    return i

If you need it to return an int (rather than a float such as 7.0), either assign int(i) to a second variable or compute int(i) twice.
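A sketch of the int-returning variant, using `float.is_integer()` as suggested in a comment on this answer (the name `isqrt_via_float` is hypothetical). It still inherits the float precision problem, as the `(10**10)**2 - 1` case shows:

```python
import math

def isqrt_via_float(n):
    i = math.sqrt(n)
    if not i.is_integer():  # same test as i != int(i)
        raise ValueError('input was not a perfect square')
    return int(i)  # convert once, so callers get an int rather than e.g. 7.0

print(isqrt_via_float(49))  # 7
# 10**20 - 1 is not a square, but it rounds to the float 1e20, whose root is exactly 1e10.
print(isqrt_via_float((10**10)**2 - 1))  # 10000000000, a false positive
```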

Eugene Yarmash
javex
  • 1
    An alternative can be `if not i.is_integer()`. Anyway, this function fails for big inputs, where the number cannot be represented as float(and probably even before that). – Bakuriu Mar 13 '13 at 16:28
  • 5
    Try calling this function with `(10**10)**2-1` and see it mistakenly think that the argument is a perfect square. – NPE Mar 13 '13 at 16:28