
Following the example here, I am able to find the column indices of a 2D numpy array and get back an array of column indices of all occurrences of the max value.

Now I want to do the same thing on a sparse csr_matrix.

import numpy as np
x = np.array([[0,0,1,0,0,0,2],[0,0,0,4,0,0,0],[0,9,1,0,0,0,2],[0,0,1,0,0,9,2]])
max_col_inds = np.argwhere(x == np.max(x))[:,1]
# array([1, 5], dtype=int64)

Then I want to use that result to get the elements at indices 1 and 5 of a 1D array:

words[max_col_inds]

If x is a 2D numpy array and words is a 1D numpy array, this works.
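For reference, here is a minimal, self-contained sketch of the dense version described above. The question never shows `words`, so the seven single-letter labels below are a hypothetical stand-in (one label per column of `x`).

```python
import numpy as np

# Dense example matrix from the question
x = np.array([[0, 0, 1, 0, 0, 0, 2],
              [0, 0, 0, 4, 0, 0, 0],
              [0, 9, 1, 0, 0, 0, 2],
              [0, 0, 1, 0, 0, 9, 2]])

# Hypothetical labels, one per column of x (7 columns)
words = np.array(["a", "b", "c", "d", "e", "f", "g"])

# (row, col) pairs for every occurrence of the maximum; keep only the column part
max_col_inds = np.argwhere(x == np.max(x))[:, 1]

# Integer-array indexing picks the label for each of those columns
selected = words[max_col_inds]
print(max_col_inds)  # [1 5]
print(selected)      # ['b' 'f']
```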

But if I replace x with a scipy.sparse.csr.csr_matrix, I get this error on the call to np.argwhere():

TypeError: tuple indices must be integers, not tuple
tony_tiger
  • `np.where` does not work with sparse. There is a similar `x.nonzero` (check the docs) – hpaulj Aug 15 '17 at 19:59
  • I'm sorry but I can't replicate your error. Perhaps make sure scipy.sparse.csr_matrix is being called correctly, since you refer to it incorrectly in the question. – Dylan Aug 15 '17 at 20:00

1 Answer

In [804]: x = np.array([[0,0,1,0,0,0,2],[0,0,0,4,0,0,0],[0,9,1,0,0,0,2],[0,0,1,0,0,9,2]])
In [805]: np.max(x)
Out[805]: 9
In [806]: np.where(x == 9)
Out[806]: (array([2, 3], dtype=int32), array([1, 5], dtype=int32))

argwhere is just np.transpose(np.where(...)); that is, it converts the tuple into a 2d array and transposes it:

In [807]: np.argwhere(x ==9)
Out[807]: 
array([[2, 1],
       [3, 5]], dtype=int32)
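The equivalence between `argwhere` and transposing the `where` tuple can be checked directly on the dense array. This is only a small sanity check, not how numpy is required to implement `argwhere` internally.

```python
import numpy as np

x = np.array([[0, 0, 1, 0, 0, 0, 2],
              [0, 0, 0, 4, 0, 0, 0],
              [0, 9, 1, 0, 0, 0, 2],
              [0, 0, 1, 0, 0, 9, 2]])

mask = x == 9
rows, cols = np.where(mask)         # one index array per axis
pairs = np.transpose((rows, cols))  # stack into shape (n_matches, 2)

# argwhere gives the same (row, col) pairs
assert np.array_equal(pairs, np.argwhere(mask))
print(pairs.tolist())  # [[2, 1], [3, 5]]
```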

Doing the same thing with a sparse matrix:

In [808]: xM = sparse.csr_matrix(x)
In [809]: xM == 9
Out[809]: 
<4x7 sparse matrix of type '<class 'numpy.bool_'>'
    with 2 stored elements in Compressed Sparse Row format>

With a single argument, np.where is the same thing as np.nonzero, and sparse matrices have their own nonzero method:

In [810]: (xM==9).nonzero()
Out[810]: (array([2, 3], dtype=int32), array([1, 5], dtype=int32))
In [811]: np.transpose((xM==9).nonzero())
Out[811]: 
array([[2, 1],
       [3, 5]], dtype=int32)
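Putting the sparse pieces together for the original question, a minimal sketch could look like this. It takes the maximum with the matrix's own `max()` method instead of `np.max`, and `words` is again a hypothetical array of column labels.

```python
import numpy as np
from scipy import sparse

x = np.array([[0, 0, 1, 0, 0, 0, 2],
              [0, 0, 0, 4, 0, 0, 0],
              [0, 9, 1, 0, 0, 0, 2],
              [0, 0, 1, 0, 0, 9, 2]])
xM = sparse.csr_matrix(x)

# Hypothetical labels, one per column
words = np.array(["a", "b", "c", "d", "e", "f", "g"])

# Comparing with the sparse max gives a sparse boolean matrix
# whose stored True entries mark the positions of the maximum
mask = xM == xM.max()

# nonzero() returns (row_indices, col_indices) of those entries
rows, cols = mask.nonzero()
print(cols)         # [1 5]
print(words[cols])  # ['b' 'f']
```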

Actually, with the current numpy, argwhere works with sparse matrices. That's because np.nonzero delegates to the matrix's own nonzero method:

In [813]: np.argwhere(xM==9)
Out[813]: 
array([[2, 1],
       [3, 5]], dtype=int32)
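If you would rather not depend on how a given numpy version dispatches `argwhere`, one alternative (a different technique from the answer above) is `scipy.sparse.find`, which returns the row indices, column indices and values of the stored nonzero entries. The sketch assumes the maximum is positive, as in this example; if every stored value were negative, the true maximum would be an implicit 0 and this filter would miss it. The order of the returned entries is not guaranteed across scipy versions, so the columns are sorted before use. `words` is a hypothetical label array.

```python
import numpy as np
from scipy import sparse

x = np.array([[0, 0, 1, 0, 0, 0, 2],
              [0, 0, 0, 4, 0, 0, 0],
              [0, 9, 1, 0, 0, 0, 2],
              [0, 0, 1, 0, 0, 9, 2]])
xM = sparse.csr_matrix(x)
words = np.array(["a", "b", "c", "d", "e", "f", "g"])  # hypothetical labels

# Row indices, column indices and values of the stored nonzero entries
rows, cols, vals = sparse.find(xM)

# Keep entries equal to the largest stored value
# (assumes the maximum is positive, so it is actually stored)
max_cols = np.sort(cols[vals == vals.max()])
print(max_cols.tolist())  # [1, 5]
print(words[max_cols])    # ['b' 'f']
```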
hpaulj