In general, there are many ways to choose an integer with a custom distribution, but most of them take weights that are proportional to the given probabilities. If the weights are log probabilities instead, then a slightly different approach is needed. Perhaps the simplest algorithm for this is rejection sampling, described below and implemented in Python. In the following algorithm, max is the maximum log-probability and there are k integers to choose from.
- Choose a uniform random integer i in [0, k).
- Get the log-weight corresponding to i, then generate an exponential(1) random number, call it ex.
- If max minus ex is less than the log-weight, return i. Otherwise, go to step 1.
Each iteration of rejection sampling takes constant time, but the expected number of iterations per sample depends greatly on the shape of the distribution: with max set to the true maximum log-weight, it equals k divided by the sum of exp(log-weight − max) over all k items, which ranges from 1 (all weights equal) up to k (one weight dominating the rest). Setting max above the true maximum only increases the rejection rate. See also Keith Schwarz's discussion of the "Fair Die/Biased Coin Loaded Die" algorithm.
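To make the shape-dependence concrete, the following small helper (my own illustration, not part of the sampling algorithm; the name expected_iterations is arbitrary) computes that expected iteration count for a list of log-weights:

import math

def expected_iterations(c):
    # Expected number of rejection-sampling iterations per sample when the
    # acceptance test uses the true maximum log-weight: k / sum(exp(c[i] - max)).
    cm = max(c)
    return len(c) / sum(math.exp(x - cm) for x in c)

For example, expected_iterations([0.0, 0.0, 0.0]) is 1.0, while expected_iterations([0.0, -50.0, -50.0]) is close to 3.0.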
Now, Python code for this algorithm follows.
import random
import math

def categ(c):
    # Do a weighted choice of an item with the
    # given log-probabilities.
    cm = max(c)  # Find max log probability
    while True:
        # Choose an item at random
        x = random.randint(0, len(c) - 1)
        # Accept it with probability proportional
        # to exp(c[x])
        y = cm - random.expovariate(1)
        # Alternatively: y = math.log(random.random()) + cm
        if y < c[x]:
            return x
The code above generates one variate at a time and uses only Python's standard library rather than NumPy. Another answer shows how rejection sampling can be implemented in NumPy by generating random variates in blocks (though it demonstrates this on a different sampling task).
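As a rough sketch of that block-based idea (not the other answer's code; the function name categ_block, the NumPy Generator API via numpy.random.default_rng, and the block size of 1024 are choices made here for illustration), rejection sampling from log-weights might look like this:

import numpy as np

def categ_block(c, rng=None, block=1024):
    # Weighted choice of an index from the log-weights in c, drawing
    # candidate indices and exponential(1) variates a block at a time.
    rng = np.random.default_rng() if rng is None else rng
    c = np.asarray(c, dtype=float)
    cm = c.max()  # Find max log-weight
    while True:
        idx = rng.integers(0, len(c), size=block)   # candidate indices
        ex = rng.exponential(1.0, size=block)       # exponential(1) variates
        accepted = np.nonzero(cm - ex < c[idx])[0]  # same acceptance test as above
        if accepted.size > 0:
            return int(idx[accepted[0]])            # first accepted candidate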
The so-called "Gumbel max trick", used especially in machine learning, can also be used to sample from a distribution with unnormalized log probabilities. This involves two steps:
- ("Gumbel") adding an independent Gumbel random variate to each log probability, namely −ln(−ln(U)), where U is a uniform random variate greater than 0 and less than 1, then
- ("max") choosing the item with the highest result, that is, the highest log probability plus Gumbel variate.
However, the time complexity for this algorithm is linear in the number of items.
The following code illustrates the Gumbel max trick:
import random
import math

def categ(c):
    # Do a weighted choice of an item with the
    # given log-probabilities, using the Gumbel max trick.
    return max([[c[i] - math.log(-math.log(random.random())), i]
                for i in range(len(c))])[1]
    # Or:
    # return max([[c[i] - math.log(random.expovariate(1)), i]
    #             for i in range(len(c))])[1]
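As a quick sanity check (my own test harness, not part of the answer; the log-weights below are arbitrary), either categ above can be run many times and the empirical frequencies compared against the normalized probabilities:

import math
from collections import Counter

logw = [0.1, -2.3, 1.7, 0.0]  # arbitrary unnormalized log-weights
n = 100000
counts = Counter(categ(logw) for _ in range(n))
total = sum(math.exp(x) for x in logw)
for i, lw in enumerate(logw):
    # Empirical frequency vs. exact probability exp(lw)/total;
    # the two columns should roughly agree.
    print(i, counts[i] / n, math.exp(lw) / total)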