
Given the following matrix,

In [0]: a = np.array([[1,2,9,4,2,5],[4,5,1,4,2,4],[2,3,6,7,8,9],[5,6,7,4,3,6]])
Out[0]: 
array([[1, 2, 9, 4, 2, 5],
       [4, 5, 1, 4, 2, 4],
       [2, 3, 6, 7, 8, 9],
       [5, 6, 7, 4, 3, 6]])

I want to get the indices of the rows that have 9 as a member. That is,

idx = [0,2]

Currently I am doing this,

import numpy as np

def myf(x):
    if any(x==9):
        return True
    else:
        return False

aux = np.apply_along_axis(myf, axis=1, arr=a)
idx = np.where(aux)[0]

And I get the result I wanted.

In [1]: idx
Out[1]: array([0, 2], dtype=int64)

But this method is slow (np.apply_along_axis calls my Python function once per row, so I suspect there is a faster way) and certainly not very Pythonic.

How can I do this in a cleaner, more Pythonic, but above all more efficient way?

Note that this question is close to this one, but here I want to apply the condition to the entire row.

myradio
  • Use `(a==9).any(axis=1)` and then `np.where`. – Divakar Feb 23 '20 at 20:57
  • A non-NumPy solution would be something like `[index for index,row in enumerate(a) if 9 in row]` (result: `[0,2]`). I wonder if there'd be a speed difference, then. – Jongware Feb 23 '20 at 20:59
  • @Divakar that's exactly what I was looking for. I was close, actually: I had tried it, but for some reason I was putting it after `np.where`. – myradio Feb 24 '20 at 08:23
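For reference, here is a minimal, self-contained sketch of the two comment suggestions above (Divakar's vectorized mask and Jongware's list comprehension), run on the matrix from the question. The variable names `idx` and `idx_py` are just for illustration:

```python
import numpy as np

a = np.array([[1, 2, 9, 4, 2, 5],
              [4, 5, 1, 4, 2, 4],
              [2, 3, 6, 7, 8, 9],
              [5, 6, 7, 4, 3, 6]])

# Vectorized: elementwise comparison gives a boolean matrix,
# .any(axis=1) reduces it to one bool per row, np.where returns the True positions.
idx = np.where((a == 9).any(axis=1))[0]
print(idx)  # [0 2]

# Pure-Python alternative: iterate over rows and test membership with `in`.
idx_py = [index for index, row in enumerate(a) if 9 in row]
print(idx_py)  # [0, 2]
```

The vectorized version does the comparison and the per-row reduction inside NumPy, whereas the list comprehension still loops over rows in Python.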

3 Answers


Use np.argwhere to find the (row, column) indices where a==9, and use the 0th column of those indices (the row numbers) to index a:

In [171]: a = np.array([[1,2,9,4,2,5],[4,5,1,4,2,4],[2,3,6,7,8,9],[5,6,7,4,3,6]])
     ...: 
     ...: indices = np.argwhere(a==9)
     ...: a[indices[:,0]]
Out[171]: 
array([[1, 2, 9, 4, 2, 5],
       [2, 3, 6, 7, 8, 9]])

...or, if you only need the row numbers, just keep indices[:,0]. If 9 can appear more than once per row and you don't want duplicate rows listed, you can use np.unique to filter your result (it changes nothing for this example):

In [173]: rows = indices[:,0]

In [174]: np.unique(rows)
Out[174]: array([0, 2])
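To see why np.unique matters, here is a small sketch on a hypothetical matrix `b` (not from the question) in which one row contains 9 twice; np.argwhere returns one (row, column) pair per matching element, so that row shows up twice until it is deduplicated:

```python
import numpy as np

# Hypothetical matrix: row 1 contains 9 twice.
b = np.array([[1, 9, 3],
              [9, 5, 9],
              [7, 8, 6]])

indices = np.argwhere(b == 9)  # one (row, col) pair per match, in row-major order
rows = indices[:, 0]
print(rows)             # [0 1 1]  (row 1 is listed twice)
print(np.unique(rows))  # [0 1]
```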
salt-die

You could combine np.argwhere and np.any:

np.argwhere(np.any(a==9,axis=1))[:,0]
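A sketch of what this one-liner does step by step, on the question's matrix. Because the per-row mask is 1-D, np.argwhere returns a column of shape (k, 1), hence the `[:,0]`; np.flatnonzero is an existing NumPy function that returns the same indices already flattened (mentioned here as an equivalent alternative, not as part of this answer):

```python
import numpy as np

a = np.array([[1, 2, 9, 4, 2, 5],
              [4, 5, 1, 4, 2, 4],
              [2, 3, 6, 7, 8, 9],
              [5, 6, 7, 4, 3, 6]])

mask = np.any(a == 9, axis=1)            # one bool per row: True, False, True, False
via_argwhere = np.argwhere(mask)[:, 0]   # argwhere gives shape (2, 1); take column 0
via_flatnonzero = np.flatnonzero(mask)   # same indices, returned as a 1-D array
print(via_argwhere, via_flatnonzero)     # [0 2] [0 2]
```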
Alain T.
  • This is good, but I still think what @Divakar suggested in the comments is cleaner. Do you expect them to be equivalent performance-wise? – myradio Feb 24 '20 at 08:24
  • @Divakar didn't write a complete solution, but his approach would indeed be faster: `np.where((a==9).any(axis=1))[0]` on the sample matrix – Alain T. Feb 24 '20 at 13:20
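If you want to check the performance question on your own data, a rough timing sketch like the one below compares the question's apply_along_axis approach with the two vectorized ones. The random 10,000 x 6 matrix `big` and the function names are assumptions for illustration; no timings are quoted because they depend on the machine and NumPy version. I would expect the apply_along_axis version to be far slower, since it calls a Python function once per row, while the two vectorized versions should be close to each other:

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
# Hypothetical test matrix: 10,000 rows of random integers in [0, 100).
big = rng.integers(0, 100, size=(10_000, 6))

def with_apply(arr):
    # Approach from the question: a Python function is called once per row.
    return np.where(np.apply_along_axis(lambda x: (x == 9).any(), 1, arr))[0]

def with_argwhere(arr):
    return np.argwhere(np.any(arr == 9, axis=1))[:, 0]

def with_where(arr):
    return np.where((arr == 9).any(axis=1))[0]

# Make sure all three agree before comparing timings.
assert np.array_equal(with_apply(big), with_where(big))
assert np.array_equal(with_argwhere(big), with_where(big))

for fn in (with_apply, with_argwhere, with_where):
    t = timeit.timeit(lambda: fn(big), number=5)
    print(f"{fn.__name__}: {t:.4f} s for 5 runs")
```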

You may try np.nonzero and np.unique:

Check on 9

np.unique((a == 9).nonzero()[0])

Out[356]: array([0, 2], dtype=int64)

Check on 6

np.unique((a == 6).nonzero()[0])

Out[358]: array([2, 3], dtype=int64)

Check on 8

np.unique((a == 8).nonzero()[0])

Out[359]: array([2], dtype=int64)

For a value that does not occur, it returns an empty array

np.unique((a == 88).nonzero()[0])

Out[360]: array([], dtype=int64)
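The same pattern can be wrapped in a small function so the value is a parameter. `rows_containing` below is a hypothetical helper name, a minimal sketch of this answer's np.nonzero + np.unique method rather than anything NumPy provides:

```python
import numpy as np

def rows_containing(arr, value):
    """Hypothetical helper: sorted indices of rows of arr that contain value at least once."""
    row_idx, _col_idx = (arr == value).nonzero()  # one entry per matching element
    return np.unique(row_idx)                     # drop duplicate rows, sort ascending

a = np.array([[1, 2, 9, 4, 2, 5],
              [4, 5, 1, 4, 2, 4],
              [2, 3, 6, 7, 8, 9],
              [5, 6, 7, 4, 3, 6]])

for v in (9, 6, 8, 88):
    print(v, rows_containing(a, v))
```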
Andy L.
  • I am not sure I follow, but I think this is not getting the intended result. What I want is a list of all the rows in which my value occurs at least once. For `9` it is `[0,2]`, for `6` it is `[2,3]` and for `8` it is `[2]`. – myradio Feb 24 '20 at 08:22
  • @myradio: ah, I misunderstood the question. I edited the answer with a different method to get your desired output. Check my updated answer. – Andy L. Feb 24 '20 at 11:02