
I'm looking for something similar to numpy.random.choice(range(3), replace=False, size=2, p=[0.1, 0.2, 0.7]) in TensorFlow.

The closest op seems to be tf.multinomial(tf.log(p), num_samples), which takes logits as input, but it can't sample without replacement. Is there any other way to sample without replacement from a non-uniform distribution in TensorFlow?

Thanks.

Daniel Widdis
TNM

2 Answers


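Yes, there is: the Gumbel-max trick, extended to the top k elements. Add independent Gumbel(0, 1) noise to the log-probabilities and take the indices of the `size` largest perturbed values; those indices are a sample without replacement from p. See here and here for some background information. The solution is: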

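# p: 1-D float32 tensor of probabilities; size: number of distinct indices to draw (TF 1.x API)
z = -tf.log(-tf.log(tf.random_uniform(tf.shape(p), minval=1e-20, maxval=1)))  # Gumbel(0, 1) noise; minval > 0 avoids log(0)
_, indices = tf.nn.top_k(tf.log(p) + z, k=size)  # indices of the `size` largest perturbed log-probabilities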
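The same Gumbel-top-k idea can be checked outside TensorFlow. Below is a minimal NumPy sketch, assuming NumPy's `Generator` API; the function name `gumbel_top_k_sample` is hypothetical, not part of any library. It perturbs the log-probabilities with Gumbel(0, 1) noise and keeps the `size` largest, so the first index of each draw should follow `p` itself, which the loop at the end checks empirically.

```python
import numpy as np


def gumbel_top_k_sample(p, size, rng):
    """Draw `size` distinct indices from the categorical distribution `p`
    (sampling without replacement) via the Gumbel-top-k trick."""
    p = np.asarray(p, dtype=np.float64)
    # Uniform samples in (0, 1): the tiny lower bound avoids log(0).
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=p.shape)
    g = -np.log(-np.log(u))  # Gumbel(0, 1) noise
    scores = np.log(p) + g
    # Indices of the `size` largest perturbed scores, largest first.
    return np.argsort(-scores)[:size]


rng = np.random.default_rng(0)
p = [0.1, 0.2, 0.7]
draws = np.array([gumbel_top_k_sample(p, 2, rng) for _ in range(20000)])

# The first index of each draw should be distributed according to p.
first_freq = np.bincount(draws[:, 0], minlength=len(p)) / len(draws)
print(first_freq)  # roughly [0.1, 0.2, 0.7]
```

Using `argsort` here plays the role of `tf.nn.top_k`; for large arrays `np.argpartition` would avoid a full sort.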
Albert

You could just use tf.py_func to wrap numpy.random.choice and make it available as a TensorFlow op:

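import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.py_func, tf.Session)

a = tf.placeholder(tf.float32)     # population to sample from
size = tf.placeholder(tf.int32)    # number of samples
replace = tf.placeholder(tf.bool)  # sample with or without replacement
p = tf.placeholder(tf.float32)     # probability of each element of a

# Run np.random.choice(a, size, replace, p) as a graph op; the output dtype matches a.
y = tf.py_func(np.random.choice, [a, size, replace, p], tf.float32)

with tf.Session() as sess:
    print(sess.run(y, {a: range(3), size: 2, replace: False, p: [0.1, 0.2, 0.7]}))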

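Because the sampling happens inside NumPy, it is controlled by NumPy's seed (tf.set_random_seed has no effect on it), so you can seed it as usual: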

with tf.Session() as sess:
    feed = {a: range(3), size: 2, replace: False, p: [0.1, 0.2, 0.7]}
    np.random.seed(1)
    print(sess.run(y, feed))
    print(sess.run(y, feed))
    print(sess.run(y, feed))
    np.random.seed(1)
    print(sess.run(y, feed))
    print(sess.run(y, feed))
    print(sess.run(y, feed))
    np.random.seed(1)
    print(sess.run(y, feed))

would print:

[ 2.  0.]
[ 2.  1.]
[ 0.  1.]
[ 2.  0.]
[ 2.  1.]
[ 0.  1.]
[ 2.  0.]
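The same seeding behaviour can be reproduced with NumPy alone, which makes it easy to confirm that re-seeding replays the identical sequence of draws. A minimal sketch without TensorFlow; the helper `three_draws` is hypothetical.

```python
import numpy as np


def three_draws():
    """Three successive weighted draws of 2 distinct items from 0, 1, 2."""
    return [
        np.random.choice(np.arange(3), size=2, replace=False, p=[0.1, 0.2, 0.7]).tolist()
        for _ in range(3)
    ]


np.random.seed(1)
first = three_draws()
np.random.seed(1)
second = three_draws()

# Re-seeding the global NumPy state replays exactly the same sequence.
print(first == second)  # True
```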
pfm
    Using `tf.py_func` is a bad idea here, because calling back into Python slows computation down dramatically. Especially on a GPU, utilization can drop from 100% to 5%, depending on the task. Since I know of no built-in op that does this, the best option is to write a custom C++ op. You can, however, hack around it, for example by iteratively sampling one element from the array and removing that element with something like `tf.gather`. – user2781994 Feb 20 '19 at 19:18
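The workaround in the comment above, drawing one element at a time and removing it, is the sequential definition of weighted sampling without replacement. A minimal NumPy sketch of that loop, assuming NumPy's `Generator.choice`; the function name `sample_without_replacement` is hypothetical, and it zeroes the drawn weight instead of gathering the remaining ones (both remove the element from later draws).

```python
import numpy as np


def sample_without_replacement(p, size, rng):
    """Draw `size` distinct indices, one at a time, from the weights `p`.

    After each draw the chosen entry's weight is set to zero and the
    remaining weights are renormalised, mirroring a sample-then-remove loop.
    """
    weights = np.asarray(p, dtype=np.float64).copy()
    chosen = []
    for _ in range(size):
        probs = weights / weights.sum()  # renormalise the remaining mass
        idx = rng.choice(len(weights), p=probs)
        chosen.append(int(idx))
        weights[idx] = 0.0  # "remove" the element so it cannot be drawn again
    return chosen


rng = np.random.default_rng(42)
print(sample_without_replacement([0.1, 0.2, 0.7], 2, rng))
```

The loop runs `size` times, so a TensorFlow port would need `size` sequential sampling steps; the Gumbel-top-k approach from the first answer needs no loop.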