The problem
The ValueError that you're seeing is, as hpaulj pointed out in the comments, caused by wrapping your p argument in list brackets []. The extra enclosing list makes np.random.choice interpret p as a 2D array of shape (1, x) instead of a 1D array. Since your first argument to np.random.choice is an integer, the function expects p to be 1D, and so it raises an error when it sees that it is not (see the choice docs for complete details).
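As a minimal sketch of the shape problem described above (the names M and p here are placeholders, not taken from your code), wrapping a 1D probability vector in a list adds a leading axis, and np.random.choice rejects it:

```python
import numpy as np

M = 5
p = np.full(M, 1.0 / M)  # a valid 1D probability vector

# Wrapping p in a list adds a leading axis of length 1.
wrapped_shape = np.asarray([p]).shape
print(wrapped_shape)

try:
    np.random.choice(M, 3, p=[p])
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Passing the 1D vector directly works.
sample = np.random.choice(M, 3, p=p)
print(sample.shape)
```

The exact wording of the error message may differ between NumPy versions.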
You can fix that ValueError by just removing the brackets around p:
enqueue = np.random.choice(M, in_freq[k], p=mat[k*N:k*N+M-1])
However, your code will now raise a new and different ValueError:
ValueError: a and p must have same size
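This second ValueError occurs because the row slices you take from your flattened NxM matrix mat are not correct. The slice mat[k*N:k*N+M-1] starts at k*N instead of k*M and contains at most M-1 elements, whereas np.random.choice(M, ...) needs exactly M probabilities.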
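To make the size mismatch concrete, here is a small sketch that compares the original slice bounds with row-aligned ones. It assumes N = 10 and M = 5 purely for illustration, and uses np.arange as a stand-in for the flattened mat so that the values are the flat indices themselves:

```python
import numpy as np

N, M = 10, 5                # assumed sizes, for illustration only
flat = np.arange(N * M)     # stand-in for the flattened matrix

bad_lengths = []
good_lengths = []
for k in range(N):
    bad = flat[k*N : k*N + M - 1]   # slicing from the question
    good = flat[M*k : M*(k + 1)]    # one full row of M entries
    bad_lengths.append(len(bad))
    good_lengths.append(len(good))

print(bad_lengths)
print(good_lengths)
```

With these sizes the original slices never have M elements, and for larger k they start past the end of the array and come back empty.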
The solution
From the description in your question, it seems like what you intended was to loop over the rows of the original 2D mat by slicing the 1D mat. Here's how you can fix the slicing in your code so that you iterate over the rows of mat in your for loop:
import numpy as np

N = 10
M = 5

in_freq = np.random.randint(0, 10, N)
mat = np.random.rand(N, M)
# normalize so that each row is a probability dist
mat /= mat.sum(axis=1, keepdims=True)
# flatten to match OP's question
mat = mat.flat

for k in range(N):
    # row k of the original 2D mat occupies flat indices [M*k, M*(k + 1))
    print((M*k, M*(k + 1)))
    enqueue = np.random.choice(M, in_freq[k], p=mat[M*k:M*(k + 1)])
Testing
To demonstrate that the correct slices are now being generated, I added a print statement that outputs the slice indices in each iteration of the for loop. Here's the output from that:
(0, 5)
(5, 10)
(10, 15)
(15, 20)
(20, 25)
(25, 30)
(30, 35)
(35, 40)
(40, 45)
(45, 50)
which shows that the slices required to iterate over the rows of mat are indeed being generated as expected now.
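If you want a check that goes beyond printing the indices, one possible sketch is to compare each slice against the corresponding row of the 2D matrix. The names mat2d and flat are introduced here only for illustration, and ravel() is used in place of .flat to get a plain 1D array:

```python
import numpy as np

N, M = 10, 5
mat2d = np.random.rand(N, M)
mat2d /= mat2d.sum(axis=1, keepdims=True)  # each row sums to 1
flat = mat2d.ravel()                       # 1D array of the same values

# Each slice should reproduce row k exactly and sum to 1.
rows_match = all(
    np.array_equal(flat[M*k:M*(k + 1)], mat2d[k]) for k in range(N)
)
sums_ok = np.allclose([flat[M*k:M*(k + 1)].sum() for k in range(N)], 1.0)
print(rows_match, sums_ok)
```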
Notes about converting a 2D index to its 1D flattened equivalent
Given an array with N rows and M columns, the general formula for converting a 2D x,y index (with x as the column index and y as the row index) into its flattened 1D equivalent i is:
i = x + M*y
You can find a more in-depth discussion in this old thread.
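As a quick illustration of the formula, the sketch below checks it against NumPy's own row-major flattening. The sizes and the example index are arbitrary choices for the demonstration, with x as the column index and y as the row index:

```python
import numpy as np

N, M = 10, 5
arr = np.arange(N * M).reshape(N, M)  # arr[y, x] holds its own flat index
flat = arr.ravel()

y, x = 3, 2          # row 3, column 2 (arbitrary example)
i = x + M * y        # formula from above

print(i, flat[i], arr[y, x])

# NumPy's helpers perform the same conversion in both directions.
i_numpy = np.ravel_multi_index((y, x), (N, M))
y_back, x_back = np.unravel_index(i, (N, M))
```

Note that this holds for NumPy's default row-major (C) order; a column-major array would use a different formula.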