
I'd like to sample from np.random.dirichlet with a large dimension as quickly as possible. More precisely, I'd like a function that approximates the one below and runs at least 10 times faster. Empirically, I observed that a small-dimension version of this function outputs one or two entries on the order of 0.1, while all other entries are so small as to be negligible. This observation isn't based on any rigorous assessment, though. The approximation doesn't need to be very accurate, but it shouldn't be too crude, as I'm using this noise for MCTS.

import timeit
import numpy as np

def g():
    return np.random.dirichlet([0.03] * 4840)

>>> timeit.timeit(g,number=1000)
0.35117408499991143
  • I think you're stuck if your parameters aren't integers. Generating gamma-distributed values (and Dirichlet is just a ratio of sums of gamma distributions) with non-integer parameters is an iterative process. – Daniel F Feb 24 '18 at 08:35
  • I suppose you could home-brew a gamma generator with lower accuracy, but it would probably have to be in C. – Daniel F Feb 24 '18 at 08:39
  • @DanielF As OP seems interested in a symmetric distribution, i.e. one involving only a single gamma distribution, they may get away with tabulating the ppf of that. Might even be fast enough in NumPy. – Paul Panzer Feb 24 '18 at 10:19
  • @PaulPanzer By tabulating, do you mean I should take samples of the gamma distribution at first, store them in an array and randomly sample one of them to use it? That sounds easier than what Daniel F proposed. – Math.StackExchange Feb 24 '18 at 10:32
  • @Math.StackExchange I've written it down in the answer below. – Paul Panzer Feb 24 '18 at 12:13
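
The relationship mentioned in the comments (a Dirichlet sample is a set of independent gamma variates divided by their sum) can be sketched as follows. This is only an illustration, not the approach from the answer; the helper name dirichlet_via_gamma and the seed are made up for the example.

```python
import numpy as np

def dirichlet_via_gamma(alpha, rng):
    """Draw one Dirichlet sample by normalizing independent gamma variates."""
    # One gamma variate per component, with shape alpha[i] and unit scale.
    g = rng.gamma(shape=alpha, scale=1.0)
    # Dividing by the total puts the result on the probability simplex.
    return g / g.sum()

rng = np.random.default_rng(0)       # arbitrary seed for reproducibility
alpha = np.full(4840, 0.03)          # symmetric case from the question
x = dirichlet_via_gamma(alpha, rng)
print(x.shape, x.sum())
```

The cost is dominated by the gamma draws, which is why the comments focus on speeding up the gamma generator.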

1 Answer


Assuming your alpha is the same for all components and is reused over many iterations, you could tabulate the ppf (inverse CDF) of the corresponding gamma distribution. This is available as scipy.stats.gamma.ppf, but we can also use scipy.special.gammaincinv directly. This function seems rather slow, so building the table is a significant upfront investment.
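
As a rough sketch of why a ppf table is enough: feeding uniform random numbers through the inverse CDF yields gamma-distributed values (inverse transform sampling). The snippet below compares the two SciPy routes on a few probabilities and then draws samples that way; the grid, sample size and seed are arbitrary choices for illustration.

```python
import numpy as np
from scipy import special, stats

alpha = 0.03

# The two ways of evaluating the gamma ppf should agree (unit scale).
u = np.linspace(0.05, 0.95, 19)
via_special = special.gammaincinv(alpha, u)
via_stats = stats.gamma.ppf(u, alpha)
print(np.allclose(via_special, via_stats, rtol=1e-6, atol=0.0))

# Inverse transform sampling: uniform draws pushed through the ppf.
rng = np.random.default_rng(0)       # arbitrary seed
samples = special.gammaincinv(alpha, rng.random(20_000))
print(samples.mean())                # the mean of gamma(alpha, 1) is alpha
```

Tabulating the ppf replaces the expensive gammaincinv call per sample with a cheap array lookup.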

Here is a crude implementation of the general idea:

import numpy as np
from scipy import special

class symm_dirichlet:
    def __init__(self, alpha, resolution=2**16):
        self.alpha = alpha
        self.resolution = resolution
        self.range, delta = np.linspace(0, 1, resolution,
                                        endpoint=False, retstep=True)
        self.range += delta / 2
        self.table = special.gammaincinv(self.alpha, self.range)
    def draw(self, n_sampl, n_comp, interp='nearest'):
        if interp != 'nearest':
            raise NotImplementedError
        gamma = self.table[np.random.randint(0, self.resolution,
                                             (n_sampl, n_comp))]
        return gamma / gamma.sum(axis=1, keepdims=True)

import time, timeit

t0 = time.perf_counter()
X = symm_dirichlet(0.03)
t1 = time.perf_counter()
print(f'Upfront cost {t1-t0:.3f} sec')
print('Running cost per 1000 samples of width 4840')
print('tabulated           {:3f} sec'.format(timeit.timeit(
    'X.draw(1, 4840)', number=1000, globals=globals())))
print('np.random.dirichlet {:3f} sec'.format(timeit.timeit(
    'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))

Sample output:

Upfront cost 13.067 sec
Running cost per 1000 samples of width 4840
tabulated           0.059365 sec
np.random.dirichlet 0.980067 sec

Better check whether it is roughly correct:

[Image: plot from the original answer, not preserved here]
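
One rough numerical check, sketched here under several assumptions, is to compare the average of the few largest components from the tabulated sampler with those from NumPy's own Dirichlet sampler. The function names tabulated_symm_dirichlet and top_k_mean are hypothetical, and the table resolution is reduced to 2**12 only to keep the example fast; a coarser table truncates the upper tail more, so some discrepancy in the largest components is to be expected.

```python
import numpy as np
from scipy import special

def tabulated_symm_dirichlet(alpha, n_sampl, n_comp, rng, resolution=2**12):
    """Approximate symmetric Dirichlet samples from a tabulated gamma ppf."""
    # Midpoints of `resolution` equal-probability bins on (0, 1).
    u = (np.arange(resolution) + 0.5) / resolution
    table = special.gammaincinv(alpha, u)
    # Nearest-entry lookup: pick random table entries as gamma variates.
    gamma = table[rng.integers(0, resolution, (n_sampl, n_comp))]
    return gamma / gamma.sum(axis=1, keepdims=True)

def top_k_mean(samples, k=5):
    """Mean over samples of the k largest components, in descending order."""
    return np.sort(samples, axis=1)[:, ::-1][:, :k].mean(axis=0)

rng = np.random.default_rng(0)       # arbitrary seed
alpha, n_comp, n_sampl = 0.03, 4840, 200
approx = tabulated_symm_dirichlet(alpha, n_sampl, n_comp, rng)
exact = rng.dirichlet(np.full(n_comp, alpha), size=n_sampl)
print('tabulated          ', top_k_mean(approx))
print('np.random.dirichlet', top_k_mean(exact))
```

If the two printed rows are of similar magnitude, the approximation is probably adequate for use as exploration noise; raising the resolution should bring them closer.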

Paul Panzer