
In MATLAB it is easy to find the indices of values that meet a particular condition:

>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2)     % find the indices where this condition is true
[3, 6, 9]          % (MATLAB uses 1-based indexing)
>> a(find(a > 2))  % get the values at those locations
[3, 3, 3]

What would be the best way to do this in Python?

So far, I have come up with the following. To just get the values:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]

But if I want the index of each of those values it's a bit more complicated:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]
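A side note on the last step: the test `i in inds` scans the whole index list once for every element of a, so the work grows roughly with len(a) times len(inds). Here is a minimal sketch of the usual alternative, which looks the values up directly by index (same data as above):

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# 0-based positions where the condition holds (MATLAB's find is 1-based).
inds = [i for (i, val) in enumerate(a) if val > 2]

# Index straight into the list instead of testing `i in inds` for every element.
vals = [a[i] for i in inds]

print(inds)  # [2, 5, 8]
print(vals)  # [3, 3, 3]
```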

Is there a better way to do this in Python, especially for arbitrary conditions (not just 'val > 2')?

I found functions equivalent to MATLAB's find in NumPy, but I currently do not have access to that library.

Michael Currie
user344226

9 Answers


In NumPy you have where, which returns the indices at which a condition holds:

>>> import numpy as np
>>> x = np.random.randint(0, 20, 10)
>>> x
array([14, 13,  1, 15,  8,  0, 17, 11, 19, 13])
>>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
joaquin
  • +1 You might also mention that you can index numpy arrays with boolean arrays, the same as you can in matlab. (e.g. `x[x>3]` instead of `np.where(x>3)`) (Not that there's anything wrong with `where`! The direct indexing may just be a more familiar form to people familiar with Matlab.) – Joe Kington May 11 '11 at 00:16
  • This is a good way, but the asker specified that he or she can't use numpy. – JasonFruit May 11 '11 at 00:33
  • @JasonFruit, you're right. I didn't get it when reading the question. I was blinded by the idea that the OP wanted to find the equivalent of a matlab function (and matlab is also big). By the way, in which situation could you have no access to numpy? – joaquin May 11 '11 at 07:42
  • Only way I can see is if your boss won't let you use it, or you're on a strange operating system or architecture. – JasonFruit May 11 '11 at 10:55
  • It looks like `where` actually returns indices, at least in version 1.6.1. It can return values if you specify it as the second argument. From docs on `argwhere`: "The output of argwhere is not suitable for indexing arrays. For this purpose use where(a) instead." – eacousineau Jan 15 '14 at 03:25
  • @eacousineau I reproduced my examples with 1.7.1 and got same results. Will check docs when having some time. – joaquin Jan 15 '14 at 08:12
  • Ah, I think it might be because you were using `x = np.arange(9)`, which in itself is a set of indices. If you change the array (offset it, scale it, reverse it, etc.), the output of `where` will still be indices: >>> x = np.arange(9)[::-1] * 10 + 33; print(x); print(np.where(x > 53)); [113 103 93 83 73 63 53 43 33] (array([0, 1, 2, 3, 4, 5]),) – eacousineau Jan 15 '14 at 12:38
  • @eacousineau You are right! `where` is actually giving indexes, not values. The example was not appropriate. I am going to edit the answer to fix the problem asap. Nobody saw that before. Incredible. – joaquin Jan 15 '14 at 13:36
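Since the asker cannot use NumPy, here is a rough pure-Python analogue of the boolean-mask indexing mentioned in the comments (x[x > 10]) and of np.where(x > 10), using itertools.compress from the standard library. It is a sketch on a plain list holding the same numbers as the output above, not a reproduction of NumPy's array behaviour:

```python
from itertools import compress

x = [14, 13, 1, 15, 8, 0, 17, 11, 19, 13]

# Boolean "mask", one True/False per element, like NumPy's x > 10.
mask = [v > 10 for v in x]

# compress() keeps the items whose selector is truthy, like x[x > 10].
values = list(compress(x, mask))

# The same mask applied to range(len(x)) yields the indices, like np.where(x > 10).
indices = list(compress(range(len(x)), mask))

print(indices)  # [0, 1, 3, 6, 7, 8, 9]
print(values)   # [14, 13, 15, 17, 11, 19, 13]
```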

You can make a function that takes a callable parameter which will be used in the condition part of your list comprehension. Then you can use a lambda or other function object to pass your arbitrary condition:

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)

>>> inds
[2, 5, 8]

It's a little closer to your Matlab example, without having to load up all of numpy.
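Because func can be any callable, the same helper covers arbitrary conditions, not just val > 2. A few illustrative predicates follow; is_odd and the words list are made up for the example:

```python
def indices(a, func):
    """Return the 0-based positions of the items of a for which func(item) is true."""
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# A lambda with a compound (chained) comparison.
print(indices(a, lambda x: 1 < x <= 2))  # [1, 4, 7]

# A plain named function.
def is_odd(x):
    return x % 2 == 1

print(indices(a, is_odd))  # [0, 2, 3, 5, 6, 8]

# An existing method, here applied to a list of strings.
words = ["spam", "EGGS", "ham", "SPAM"]
print(indices(words, str.isupper))  # [1, 3]
```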

John
  • I think the question already contains better code than this version: `inds = [i for (i, val) in enumerate(a) if val > 2]`, which is a one-line solution. – beahacker Feb 07 '17 at 13:40

Or use numpy's nonzero function:

import numpy as np
a    = np.array([1,2,3,4,5])
inds = np.nonzero(a>2)
a[inds] 
array([3, 4, 5])
vincentv

Why not just use this:

[i for i in range(len(a)) if a[i] > 2]

or for arbitrary conditions, define a function f for your condition and do:

[i for i in range(len(a)) if f(a[i])]
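If you need both the indices and the values, one way is to collect (index, value) pairs in a single pass and then unzip them with zip(*...). A sketch, with f standing in for whatever condition you define:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

def f(x):
    # Example condition; substitute any predicate.
    return x > 2

# Collect (index, value) pairs in one pass over the list.
pairs = [(i, a[i]) for i in range(len(a)) if f(a[i])]

# zip(*pairs) regroups the pairs into a tuple of indices and a tuple of values.
inds, vals = zip(*pairs)

print(inds)  # (2, 5, 8)
print(vals)  # (3, 3, 3)
```

Note that the unpacking raises a ValueError when nothing matches, because zip(*[]) produces nothing to unpack; check that pairs is non-empty first if that can happen.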
JasonFruit

The NumPy routine more commonly used for this is numpy.where(); when called with only a condition, it behaves the same as numpy.nonzero().

import numpy
a    = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)

To get the values, you can store the indices and index the array with them:

a[inds]

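Note that numpy.where(a>2, a), with the array as a lone second argument, raises a ValueError: the optional x and y arguments must be given together. With both, numpy.where selects element-wise, taking values from the first array where the condition is true and from the second array elsewhere:

b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)   # array([11, 22,  3,  4,  5])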
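The three-argument form numpy.where(cond, x, y) picks each element from x where cond is true and from y elsewhere. For readers without NumPy, a rough pure-Python analogue, with plain lists standing in for the arrays (this is an illustration of the semantics, not NumPy's implementation):

```python
a = [1, 2, 3, 4, 5]
b = [11, 22, 33, 44, 55]

# The condition, evaluated element by element (like a > 2 on a NumPy array).
cond = [v > 2 for v in a]

# Element-wise choice: take from a where cond is True, otherwise from b.
picked = [x if c else y for c, x, y in zip(cond, a, b)]

print(picked)  # [11, 22, 3, 4, 5]
```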
ryanjdillon

I've been trying to figure out a fast way to do this exact thing, and here is what I stumbled upon (uses numpy for its fast vector comparison):

import numpy
a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]

It turns out that this is much faster than:

inds = [i for (i, val) in enumerate(a) if val > 2]

It seems that the comparison is faster when NumPy performs it on the whole array at once, and/or the list comprehension is faster when it only checks truth values rather than evaluating a comparison.

Edit:

I was revisiting my code and came across a possibly less memory-intensive, slightly faster, and very concise way of doing this in one line:

inds = numpy.arange(len(a))[numpy.array(a) > 2]
Nate

To get values with arbitrary conditions, you could use filter() with a lambda function:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> list(filter(lambda x: x > 2, a))
[3, 3, 3]

One possible way to get the indices would be to use enumerate() to build a tuple with both indices and values, and then filter that:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print(aind)
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> tuple(filter(lambda x: x[1] > 2, aind))
((2, 3), (5, 3), (8, 3))
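filter() can also be applied directly to enumerate(a), without building the intermediate tuple, and the indices can then be pulled out of the surviving pairs. A sketch for Python 3, where filter() returns a lazy iterator (hence the list() call):

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# Keep only the (index, value) pairs whose value satisfies the condition.
pairs = list(filter(lambda p: p[1] > 2, enumerate(a)))

# Split the surviving pairs into indices and values.
inds = [i for (i, _) in pairs]
vals = [v for (_, v) in pairs]

print(pairs)  # [(2, 3), (5, 3), (8, 3)]
print(inds)   # [2, 5, 8]
print(vals)   # [3, 3, 3]
```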
Blair

I think I may have found a quick and simple substitute. By the way, I find the np.where() function not very satisfactory here, because its result contains an annoying extra array of zeros.

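import numpy as np
# Note: mlab.find has since been removed from Matplotlib;
# np.flatnonzero(a < 0) returns the same flattened indices.
import matplotlib.mlab as mlab
a = np.random.randn(1,5)
print(a)

>> [[ 1.36406736  1.45217257 -0.06896245  0.98429727 -0.59281957]]

idx = mlab.find(a<0)
print(idx)
print(type(idx))

>> [2 4]
>> <class 'numpy.ndarray'>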
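The extra zeros come from np.where reporting one index array per dimension: for a 1-by-5 array, the first array holds the row indices, which are all 0. mlab.find instead returns indices into the flattened (row-major) array. A pure-Python sketch of the two conventions, using a hypothetical nested list with the values above rounded, and assuming all rows have equal length:

```python
# Hypothetical 1-by-5 nested list standing in for np.random.randn(1, 5) above (values rounded).
a = [[1.36, 1.45, -0.07, 0.98, -0.59]]
ncols = len(a[0])  # assumes every row has the same length

# One (row, col) pair per match. np.where(a < 0) reports the same positions
# as separate arrays, rows first, which is where the zeros come from.
rc = [(r, c) for r, row in enumerate(a) for c, v in enumerate(row) if v < 0]
rows = [r for (r, _) in rc]
cols = [c for (_, c) in rc]

# Row-major ("flattened") indices, the convention mlab.find uses.
flat = [r * ncols + c for (r, c) in rc]

print(rows, cols)  # [0, 0] [2, 4]
print(flat)        # [2, 4]
```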

Best, Da

DidasW

Matlab's find also accepts an optional second argument, n, which returns only the first n indices. John's code covers the basic form but not this one. For instance, if you only want the first index at which the condition is satisfied, the Matlab call would be:

find(x>2,1)

Using John's code, you can index the list that indices() returns: [0] gives the first match, like find(x>2,1), and a slice such as [:n] gives the first n matches, like find(x>2,n).

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)[0]  # [0] picks the first match (IndexError if there is none)

which returns 2, the 0-based index of the first element greater than 2 (Matlab's find(a>2,1) would return 3).
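A slightly more general sketch of Matlab's find(X, n) in plain Python. The function name find and its n parameter are hypothetical, chosen to mirror Matlab; unlike indexing with [0], it stops scanning once n matches are found and returns an empty list instead of raising IndexError when nothing matches:

```python
from itertools import islice

def find(a, func, n=None):
    """Rough analogue of Matlab's find(X, n) for a 1-D list, with 0-based indices.

    Returns all matching indices, or at most the first n if n is given.
    """
    matches = (i for (i, val) in enumerate(a) if func(val))  # lazy generator
    return list(matches) if n is None else list(islice(matches, n))

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

print(find(a, lambda x: x > 2))     # [2, 5, 8]
print(find(a, lambda x: x > 2, 1))  # [2]
print(find(a, lambda x: x > 5, 1))  # [] rather than an IndexError

# For just the first index, next() with a default avoids building any list.
first = next((i for (i, val) in enumerate(a) if val > 2), None)
print(first)  # 2
```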