
I have a method that selects columns from a pandas DataFrame by their labels, but indexing by integer position via NumPy is much faster.

Is there a way in pandas or numpy to go from column labels to column indices without iterating?

import numpy as np
import pandas as pd

DF_var = pd.DataFrame(np.random.random((5, 10)), columns=["attr_%d" % i for i in range(10)])
query_cols = ["attr_2", "attr_5", "attr_6", "attr_0"]
want_idx = [0, 2, 5, 6]

# Something like np.where w/o iterating through? 
# np.where(query_cols in DF_var.columns)
# TypeError: unhashable type: 'list'

# np.where(x in DF_var.columns for x in query_cols)
# (array([0]),)


# The long way: loop over every column and keep the positions of matching labels.
long_way = []
for i, label in enumerate(DF_var.columns):
    if label in query_cols:
        long_way.append(i)
# print(sorted(long_way))
# [0, 2, 5, 6]


O.rka
  • See: http://stackoverflow.com/questions/13021654/retrieving-column-index-from-column-name-in-python-pandas – albert Jul 20 '16 at 19:06
  • Is this for a single value or for a list? – O.rka Jul 20 '16 at 19:14
  • @O.rka single, but you can use a list comprehension to get all the indices. – Alex Jul 20 '16 at 19:15
  • This does the job using numpy: `np.argwhere(DF_var.columns.isin(["attr_2","attr_5","attr_6","attr_0"])).flatten()` – sirfz Jul 20 '16 at 19:20
  • @O.rka You can use the `searchsorted` method. New answer posted in the linked duplicate target. – Divakar Jul 20 '16 at 19:38
  • @Divakar hey thanks, do you have to convert to a `pd.Categorical` object? http://pandas.pydata.org/pandas-docs/version/0.18.1/generated/pandas.Index.searchsorted.html – O.rka Jul 20 '16 at 20:07
  • Well, the easiest way was to extract the column names as an array and then use NumPy's `searchsorted` function. Here's the [solution](http://stackoverflow.com/a/38489403/3293881). – Divakar Jul 20 '16 at 20:11
  • Since you don't care about the order, you can simply do: `np.where(np.in1d(df.columns, query_cols))[0]`. – Divakar Jul 20 '16 at 20:15
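
The two vectorized ideas from the comments above can be sketched roughly as follows. This is a minimal illustration, not the commenters' exact code: it assumes the column labels are unique strings and that every label in `query_cols` is actually present (the `searchsorted` variant returns wrong positions otherwise). The names `mask_idx`, `cols`, `sorter`, and `ordered_idx` are made up for this example.

```python
import numpy as np
import pandas as pd

DF_var = pd.DataFrame(np.random.random((5, 10)), columns=["attr_%d" % i for i in range(10)])
query_cols = ["attr_2", "attr_5", "attr_6", "attr_0"]

# Variant 1: boolean mask over the columns, then the positions of the True entries.
# The result comes back in column order, not in the order of query_cols.
mask_idx = np.flatnonzero(DF_var.columns.isin(query_cols))

# Variant 2: searchsorted needs sorted input, so pass an argsort of the labels as
# `sorter` and map the hits back to the original positions.
# Assumption: every queried label exists in the columns.
cols = DF_var.columns.to_numpy().astype(str)
sorter = np.argsort(cols)
ordered_idx = sorter[np.searchsorted(cols, query_cols, sorter=sorter)]

print(mask_idx)
print(ordered_idx)
```

The mask variant ignores the order of `query_cols`, while the `searchsorted` variant preserves it, so which one fits depends on whether the order of the requested labels matters.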

1 Answer

short_way = [DF_var.columns.get_loc(col) for col in query_cols]
print(sorted(short_way))
# outputs [0, 2, 5, 6]
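
If the goal is to avoid the Python-level loop entirely, one possible alternative is `Index.get_indexer`, which looks up a whole list of labels in one call. The sketch below is an assumption-laden illustration rather than part of the original answer: it assumes the column labels are unique (pandas raises an error for `get_indexer` on a non-unique index) and treats `-1` in the result as "label not found". The names `idx` and `values` are hypothetical.

```python
import numpy as np
import pandas as pd

DF_var = pd.DataFrame(np.random.random((5, 10)), columns=["attr_%d" % i for i in range(10)])
query_cols = ["attr_2", "attr_5", "attr_6", "attr_0"]

# One call resolves all labels to integer positions, in the order of query_cols.
# Labels that are not present would show up as -1.
idx = DF_var.columns.get_indexer(query_cols)

# Use the sorted positions to slice the underlying NumPy array directly.
values = DF_var.to_numpy()[:, np.sort(idx)]

print(sorted(idx))
```

Checking for `-1` before slicing is worthwhile, since a `-1` position would silently select the last column of the array.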
Alex