I have written my own Sieve of Eratosthenes implementation in NumPy. I am sure you all know that it finds all primes up to a given number, so I won't explain it further.
Code:
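```python
import numpy as np

def primes_sieve(n):
    primes = np.ones(n+1, dtype=bool)
    primes[:2] = False          # 0 and 1 are not prime
    primes[4::2] = False        # even numbers greater than 2
    for i in range(3, int(n**0.5)+1, 2):  # odd i up to floor(sqrt(n))
        if primes[i]:
            primes[i*i::i] = False  # cross out multiples of i from i*i
    return np.where(primes)[0]
```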
As you can see, I have already made some optimizations. First, all primes except 2 are odd, so I set all multiples of 2 (other than 2 itself) to False and only loop over odd numbers.
Second, I only loop up to the floor of the square root of n, because every composite number up to n has a prime factor of at most √n, so it has already been crossed out as a multiple of that prime by the time the loop ends.
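A common variant takes both ideas one step further by not storing the even numbers at all. The sketch below is an illustration under stated assumptions, not the code benchmarked in this question: it keeps one flag per odd number (index `k` stands for `2*k + 1`), so the array is half the size and the `primes[4::2] = False` step disappears. The name `primes_sieve_odd_only` is hypothetical, and `math.isqrt` (Python 3.8+) replaces `int(n**0.5)` to avoid floating-point rounding for very large `n`.

```python
import math
import numpy as np

def primes_sieve_odd_only(n: int) -> np.ndarray:
    """Hypothetical helper: primes <= n, storing a flag only for odd numbers."""
    if n < 2:
        return np.array([], dtype=np.intp)
    # Index k represents the odd number 2*k + 1, i.e. 1, 3, 5, ... up to n.
    is_odd_prime = np.ones((n + 1) // 2, dtype=bool)
    is_odd_prime[0] = False                      # 1 is not prime
    for i in range(3, math.isqrt(n) + 1, 2):     # odd i up to floor(sqrt(n))
        if is_odd_prime[i // 2]:
            # i*i is odd, so it sits at index i*i // 2; moving the index by i
            # moves the value by 2*i, i.e. to the next odd multiple of i.
            is_odd_prime[i * i // 2 :: i] = False
    odd_primes = 2 * np.flatnonzero(is_odd_prime) + 1
    return np.concatenate(([2], odd_primes))     # 2 is the only even prime

print(primes_sieve_odd_only(30))
```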
But it isn't optimal, because it still loops through every odd number up to the square-root limit, and not all odd numbers are prime. As n grows larger, primes become sparser, so there are more and more redundant iterations.
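To put numbers on this for the benchmark size used below, n = 65536, the sketch counts how many odd `i` the loop in `primes_sieve` visits and how many of them are prime. The helper `count_loop_candidates` is hypothetical and uses plain trial division, which is fine for values this small.

```python
import math

def count_loop_candidates(n):
    """Hypothetical helper: (odd i visited by the loop in primes_sieve, how many are prime)."""
    def is_prime(m):
        # Plain trial division; fast enough for m <= sqrt(n).
        return m >= 2 and all(m % d for d in range(2, math.isqrt(m) + 1))

    candidates = range(3, int(n**0.5) + 1, 2)  # same range as in primes_sieve
    return len(candidates), sum(1 for i in candidates if is_prime(i))

print(count_loop_candidates(65536))  # (127, 53)
```

So at this size the loop runs 127 times, and the 74 composite values of `i` are each rejected by `if primes[i]:` with a single array lookup.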
So if the list of candidates were updated dynamically, so that composite numbers already identified were never iterated over and only primes were looped through, there would be no wasted iterations, and the algorithm would be optimal.
I have written a crude implementation of the optimized version:
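```python
import numpy as np

def primes_sieve_opt(n):
    primes = np.ones(n+1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False
    limit = int(n**0.5)+1
    i = 2
    while i < limit:
        primes[i*i::i] = False
        # Jump to the next True after i. argmax() returns 0 if the slice has no
        # True at all; here a prime after i always exists while i < sqrt(n).
        i += 1 + primes[i+1:].argmax()
    return np.where(primes)[0]
```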
But it is actually slower than the unoptimized version (by about 14% in this benchmark):
```
In [92]: %timeit primes_sieve(65536)
271 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [102]: %timeit primes_sieve_opt(65536)
309 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
My idea is simple: by jumping to the index of the next `True`, I can ensure that all primes are covered and only primes are processed.
However, `np.argmax` is slow for this. I searched Google for "how to find the index of the next True value in NumPy array" (without quotes) and found several Stack Overflow questions that are somewhat relevant but ultimately don't answer my question.
For example, "numpy get index where value is true" and "Numpy first occurrence of value greater than existing value".
I am not trying to find all indices where the value is `True`; that would be extremely wasteful. I need to find the next `True` value, get its index, and stop scanning immediately. The array contains only `bool`s.
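One detail worth handling in any "next `True`" helper: `argmax()` on a boolean array returns the index of the first `True`, but it also returns 0 when there is no `True` at all, so `primes[i+1:].argmax()` alone cannot tell "the very next element" apart from "nothing left". A small sketch (the name `next_true` is hypothetical) that returns -1 in the second case:

```python
import numpy as np

def next_true(mask: np.ndarray, start: int) -> int:
    """Hypothetical helper: index of the first True in mask[start:], or -1 if none."""
    rest = mask[start:]          # a view, no copy
    if rest.size == 0:
        return -1
    j = int(rest.argmax())       # first True, or 0 if rest is all False
    return start + j if rest[j] else -1

mask = np.array([False, True, False, False, True])
print(next_true(mask, 0), next_true(mask, 2), next_true(mask, 5))  # 1 4 -1
```

In `primes_sieve_opt`, the update would then become `i = next_true(primes, i + 1)`, with the loop also stopping when it returns -1. This only makes the edge case explicit; it does not by itself make the search any faster.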
How can I optimize this?
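One alternative that loops only over primes without searching for the next `True` at all: the primes up to √n are exactly the output of a smaller sieve, so they can be computed first (here recursively) and then used to cross out multiples in the full array. This is a sketch of a different technique under stated assumptions, not a benchmarked replacement; the name `primes_sieve_recursive` is hypothetical, and whether it is faster than `primes_sieve` would need to be measured.

```python
import math
import numpy as np

def primes_sieve_recursive(n: int) -> np.ndarray:
    """Hypothetical helper: return all primes <= n, looping only over primes."""
    if n < 2:
        return np.array([], dtype=np.intp)
    is_prime = np.ones(n + 1, dtype=bool)
    is_prime[:2] = False
    is_prime[4::2] = False                 # even numbers > 2
    root = math.isqrt(n)
    if root >= 3:
        # The primes <= sqrt(n) come from a much smaller sieve; skip 2 (index 0),
        # since the even numbers are already handled above.
        for p in primes_sieve_recursive(root)[1:]:
            is_prime[p * p :: 2 * p] = False  # odd multiples of p from p*p
    return np.flatnonzero(is_prime)

print(primes_sieve_recursive(30))
```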
Edit
If anyone is interested, I have optimized my algorithm further:
```python
import numba
import numpy as np

@numba.jit(nopython=True, parallel=True, fastmath=True, forceobj=False)
def prime_sieve(n: int) -> np.ndarray:
    primes = np.full(n + 1, True)
    primes[:2] = False
    primes[4::2] = False    # even numbers > 2
    primes[9::6] = False    # odd multiples of 3 > 3
    limit = int(n**0.5) + 1
    # Candidates of the form 6k-1: 5, 11, 17, ...
    for i in range(5, limit, 6):
        if primes[i]:
            primes[i * i :: 2 * i] = False  # odd multiples of i from i*i
    # Candidates of the form 6k+1: 7, 13, 19, ...
    for i in range(7, limit, 6):
        if primes[i]:
            primes[i * i :: 2 * i] = False
    return np.flatnonzero(primes)
```
I used `numba` to speed things up. And since every prime except 2 and 3 has the form 6k+1 or 6k-1, only those candidates need to be checked, which makes things even faster.
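For readers without `numba`, the same 6k±1 idea can be sketched in plain NumPy. This variant (hypothetical name `prime_sieve_6k`) also walks the candidates 5, 7, 11, 13, ... in a single increasing loop instead of two separate ones. With the two-loop order in `prime_sieve`, a 6k-1 composite such as 77 = 7 × 11 is still `True` when the first loop reaches it (for n ≥ 77² = 5929), because 7 is only sieved in the second loop, so its multiples are crossed out redundantly. The output is the same either way; the single loop just avoids that extra work.

```python
import math
import numpy as np

def prime_sieve_6k(n: int) -> np.ndarray:
    """Hypothetical plain-NumPy sketch of the 6k±1 sieve; returns primes <= n."""
    primes = np.ones(n + 1, dtype=bool)
    primes[:2] = False
    primes[4::2] = False    # even numbers > 2
    primes[9::6] = False    # odd multiples of 3 > 3
    limit = math.isqrt(n) + 1
    # Visit 5, 7, 11, 13, 17, 19, ... (6k-1, then 6k+1) in increasing order, so a
    # composite candidate is already crossed out by its smallest prime factor.
    for base in range(5, limit, 6):
        for i in (base, base + 2):
            if i < limit and primes[i]:
                primes[i * i :: 2 * i] = False  # odd multiples of i from i*i
    return np.flatnonzero(primes)

print(prime_sieve_6k(30))
```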