
When I summarize a dataframe and join the summary back onto the original dataframe, I have trouble working with the resulting column names.

This is the original dataframe:

import pandas as pd

d = {'col1': ["a", "a", "b", "a", "b", "a"], 'col2': [0, 4, 3, -5, 3, 4]}
df = pd.DataFrame(data=d)

Now I calculate some statistics and merge the summary back in:

group_summary = df.groupby('col1', as_index = False).agg({'col2': ['mean', 'count']})
df = pd.merge(df, group_summary, on = 'col1')

The dataframe has some strange column names now:

df
Out: 
  col1  col2  (col2, mean)  (col2, count)
0    a     0          0.75              4
1    a     4          0.75              4
2    a    -5          0.75              4
3    a     4          0.75              4
4    b     3          3.00              2
5    b     3          3.00              2

I know I can access the columns by position, e.g. df.iloc[:, 2], but I would also like to access them by name, e.g. df['(col2, mean)']. That returns a KeyError.

Source: This grew out of this previous question.

ulima2_

1 Answer


This happens because your GroupBy.agg call returns a DataFrame whose columns are a MultiIndex. When a DataFrame with a single-level header is merged with one that has MultiIndex columns, the MultiIndex is flattened into tuples, which is why the string '(col2, mean)' is not a valid key.
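
To make the two-level header visible, here is a minimal sketch that prints the summary's columns and then flattens them before merging. It assumes the example data from the question; the names `merged`, `col2_mean` and `col2_count` are illustrative choices, not anything pandas produces by default.

```python
import pandas as pd

df = pd.DataFrame({'col1': ["a", "a", "b", "a", "b", "a"],
                   'col2': [0, 4, 3, -5, 3, 4]})

# With the default as_index=True, col1 stays in the index, so only the
# aggregated columns carry the two-level (MultiIndex) header.
group_summary = df.groupby('col1').agg({'col2': ['mean', 'count']})
print(group_summary.columns.tolist())

# Join the two levels into a single string per column (hypothetical naming
# scheme), then turn col1 back into an ordinary column.
group_summary.columns = ['_'.join(levels) for levels in group_summary.columns]
group_summary = group_summary.reset_index()

# Both frames now have single-level headers, so the merge keeps string names.
merged = df.merge(group_summary, on='col1')
print(merged.columns.tolist())
```

After flattening, the statistics can be selected with plain string keys such as merged['col2_mean'].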

Fix your groupby code as follows:

group_summary = df.groupby('col1', as_index=False)['col2'].agg(['mean', 'count'])

Merge now gives flat columns.

df.merge(group_summary, on='col1')

  col1  col2  mean  count
0    a     0  0.75      4
1    a     4  0.75      4
2    a    -5  0.75      4
3    a     4  0.75      4
4    b     3  3.00      2
5    b     3  3.00      2
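
If you would rather choose the output column names yourself, named aggregation (available since pandas 0.25) is another way to get a flat header. This is a sketch using the question's data; `col2_mean`, `col2_count` and `merged` are names picked for illustration.

```python
import pandas as pd

df = pd.DataFrame({'col1': ["a", "a", "b", "a", "b", "a"],
                   'col2': [0, 4, 3, -5, 3, 4]})

# Each keyword is an output column name; each value is (input column, function).
group_summary = df.groupby('col1', as_index=False).agg(
    col2_mean=('col2', 'mean'),
    col2_count=('col2', 'count'),
)

# The summary has a single-level header, so the merged names are plain strings.
merged = df.merge(group_summary, on='col1')
print(merged)
```

Prefixed names also avoid clashing with DataFrame methods such as mean and count when using attribute access.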

Better still, use transform, which returns a result with the same index as the input, so no merge is needed.

g = df.groupby('col1', as_index=False)['col2']
df.assign(mean=g.transform('mean'), count=g.transform('count'))

  col1  col2  mean  count
0    a     0  0.75      4
1    a     4  0.75      4
2    b     3  3.00      2
3    a    -5  0.75      4
4    b     3  3.00      2
5    a     4  0.75      4
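
The following sketch shows why no merge is needed here: the transformed Series carries the same index as the original frame, so it lines up row by row. It uses the question's data; `group_mean`, `result` and the prefixed column names are illustrative.

```python
import pandas as pd

df = pd.DataFrame({'col1': ["a", "a", "b", "a", "b", "a"],
                   'col2': [0, 4, 3, -5, 3, 4]})

g = df.groupby('col1')['col2']

# transform broadcasts each group's statistic back to every row of that group.
group_mean = g.transform('mean')
print(group_mean.index.equals(df.index))

# Because the index matches, the result can be assigned directly as new columns
# and the original row order is preserved.
result = df.assign(col2_mean=group_mean, col2_count=g.transform('count'))
print(result)
```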

Pro tip: describe computes several useful statistics in a single function call.

df.groupby('col1').describe()

      col2                                          
     count  mean       std  min   25%  50%  75%  max
col1                                                
a      4.0  0.75  4.272002 -5.0 -1.25  2.0  4.0  4.0
b      2.0  3.00  0.000000  3.0  3.00  3.0  3.0  3.0
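
Note that the output above again has a two-level header. As a sketch, selecting the column before calling describe gives a flat header instead, from which individual statistics can be picked by name. The variable names `stats` and `subset` are illustrative.

```python
import pandas as pd

df = pd.DataFrame({'col1': ["a", "a", "b", "a", "b", "a"],
                   'col2': [0, 4, 3, -5, 3, 4]})

# Describing a single selected column yields one header level
# (count, mean, std, min, 25%, 50%, 75%, max).
stats = df.groupby('col1')['col2'].describe()
print(stats.columns.tolist())

# Keep only the statistics of interest and restore col1 as a column.
subset = stats[['mean', 'count']].reset_index()
print(subset)
```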

Also see Get statistics for each group (such as count, mean, etc) using pandas GroupBy?

cs95