Inspired by one of the SO answers, np.random.dirichlet or np.random.gamma could be used, cautiously, for this. In one of my tests:
import numpy as np

spec_len = 15
result = np.random.dirichlet(np.ones(spec_len), size=1)
# [[0.04798813 0.04001169 0.03344301 0.03314941 0.02594012 0.00738061
# 0.05712809 0.20175982 0.0605601 0.17495591 0.01114262 0.12018851
# 0.08900246 0.05989448 0.03745501]]
# SUM --> 0.9999999999999999
shape, scale = 0.5, 0.25
result = np.random.gamma(shape, scale, spec_len)
# [5.44188272e-03 2.81687195e-01 1.68385119e-01 1.77502131e-03
# 1.21338381e-03 1.12168743e-02 5.98830384e-02 5.69830641e-02
# 3.13285820e-02 3.92879720e-01 1.69169125e-02 3.99294001e-07
# 1.24306290e-01 1.06084121e-02 4.63536093e-02]
# SUM --> 1.208979502912514
Both give sums near 1. np.random.dirichlet sums to 1 up to floating-point rounding. np.random.gamma is more flexible and can be used for other magnitudes, since its sum is controlled by shape and scale, so it is only near 1 for suitable parameters.
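As a rough check of the statement above, here is a minimal sketch that compares the two sums. It assumes the newer `np.random.default_rng` generator API instead of the legacy `np.random.*` functions used in this answer. The seed and the number of repetitions are arbitrary choices for illustration. The expected sum of the gamma draws is `spec_len * shape * scale`, because the mean of a gamma distribution is `shape * scale`.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen only for reproducibility (assumption)
spec_len = 15
shape, scale = 0.5, 0.25

# A Dirichlet sample is normalized, so its sum is 1 up to rounding error.
dirichlet_sum = rng.dirichlet(np.ones(spec_len)).sum()

# The sum of independent gamma draws is itself random;
# its expected value is spec_len * shape * scale.
gamma_sums = rng.gamma(shape, scale, size=(100_000, spec_len)).sum(axis=1)
expected_gamma_sum = spec_len * shape * scale

print(dirichlet_sum)
print(gamma_sums.mean(), expected_gamma_sum)
```

With the parameters from the test above, the gamma sum averages well above 1, so shape and scale would need tuning if a sum close to 1 is required.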
These methods can be used in a loop to ensure the sum is >= 1, and they are still reasonably fast. E.g.:
total = 0  # renamed from 'sum' to avoid shadowing the built-in
while total <= 1:  # or: total < 1
    result = np.random.dirichlet(np.ones(spec_len), size=1)
    total = np.sum(result)
# SUM --> 1.0000000000000002
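The same loop could be wrapped in a small helper. This is only a sketch: the function name `draw_with_sum_at_least_one` and the `max_tries` guard are hypothetical additions for illustration, and it uses `np.random.default_rng` rather than the legacy functions above.

```python
import numpy as np


def draw_with_sum_at_least_one(spec_len, rng, max_tries=1000):
    """Redraw a Dirichlet sample until its floating-point sum is >= 1."""
    for _ in range(max_tries):
        result = rng.dirichlet(np.ones(spec_len))
        if result.sum() >= 1:
            return result
    # Guard against an endless loop (hypothetical safety net).
    raise RuntimeError("no sample reached a sum of 1 within max_tries")


rng = np.random.default_rng(0)  # arbitrary seed
sample = draw_with_sum_at_least_one(15, rng)
print(sample.sum())
```

Since the Dirichlet sum differs from 1 only by rounding error, the loop is expected to finish after very few draws.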
The fastest way that I know of, if iterating is not an issue, is to put a Numba decorator on your function. The decorator can be used as:
import numba as nb
from random import random

# without parallelization: @nb.njit() --> with signatures: @nb.njit('float64(int64)')
@nb.njit('float64(int64)', parallel=True)
def funct(n):  # e.g. n = 10_000
    media = 0
    for _ in nb.prange(n):  # change to 'for _ in range(n):' if not parallelized
        result = 0.0
        count = 0
        # draw uniform numbers until their running sum reaches 1
        while result < 1:
            x = random()
            result += x
            count += 1
            if result >= 1:
                break
        media += count
    return media / n
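For checking the output of the function above, a plain-Python reference can be handy. The sketch below is not the benchmarked code. The name `mean_draws_until_one` and the `seed` parameter are hypothetical. It computes the same quantity, the average number of uniform draws needed for the running sum to reach 1, whose expected value is known to be e.

```python
import math
import random


def mean_draws_until_one(n, seed=0):
    """Average number of uniform(0, 1) draws needed for the sum to reach 1."""
    rng = random.Random(seed)  # local generator; the seed is an arbitrary choice
    total_count = 0
    for _ in range(n):
        running_sum = 0.0
        count = 0
        while running_sum < 1:
            running_sum += rng.random()
            count += 1
        total_count += count
    return total_count / n


estimate = mean_draws_until_one(100_000)
print(estimate, math.e)
```

The estimate should approach e as n grows, which gives a quick sanity check for any faster variant.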
In my benchmarks, Numba with parallelization was the fastest: about 16 times faster than the OP's code.
Benchmarks (Colab CPU)
The results are for n=10000, n=100000, n=1000000, and n=10000000, respectively:
OP code:
100 loops, best of 5: 6.45 ms per loop
100 loops, best of 5: 64.9 ms per loop
100 loops, best of 5: 655 ms per loop
n=10000000 --> not tested; estimated ~ 6 s per loop
NumPy dirichlet without while (the sum must still be checked to ensure >= 1): (~ x12)
100 loops, best of 5: 528 µs per loop
100 loops, best of 5: 5.59 ms per loop
100 loops, best of 5: 55.7 ms per loop
100 loops, best of 5: 550 ms per loop
NumPy dirichlet with while: (~ x6)
100 loops, best of 5: 1.1 ms per loop
100 loops, best of 5: 11.4 ms per loop
100 loops, best of 5: 107 ms per loop
100 loops, best of 5: 1.07 s per loop
Numba nopython (njit), about the same with or without signatures: (~ x7)
100 loops, best of 5: 899 µs per loop
100 loops, best of 5: 9.13 ms per loop
100 loops, best of 5: 91.5 ms per loop
100 loops, best of 5: 914 ms per loop
Numba nopython(njit) + signatures + parallel: <-- THE FASTEST (~ x16)
100 loops, best of 5: 404 µs per loop
100 loops, best of 5: 3.96 ms per loop
100 loops, best of 5: 40 ms per loop
100 loops, best of 5: 401 ms per loop
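The speedup factors quoted in the labels above can be recomputed from the n=10000 timings. This is just the arithmetic, with the reported values copied into a dictionary and converted to milliseconds:

```python
# Reported timings for n=10000, in milliseconds (copied from the results above).
timings_ms = {
    "OP code": 6.45,
    "NumPy dirichlet without while": 0.528,
    "NumPy dirichlet with while": 1.1,
    "Numba njit": 0.899,
    "Numba njit + signatures + parallel": 0.404,
}

baseline = timings_ms["OP code"]
# Speedup of each variant relative to the OP code, rounded to one decimal.
speedups = {name: round(baseline / t, 1) for name, t in timings_ms.items()}

for name, factor in speedups.items():
    print(f"{name}: x{factor}")
```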