When and why is the sort flag of a DataFrame grouping ignored in pd.GroupBy.apply()?
The problem is best understood with an example. In the following four equivalent solutions to a dummy problem, approaches 1 and 4 respect the sort flag, while approaches 2 and 3 ignore it for some reason.
import pandas as pd
import numpy as np
#################################################
# Construct input data:
cats = list("bcabca")
vals = np.arange(0,10*len(cats),10)
df = pd.DataFrame({"i": cats, "ii": vals})
# df:
#    i  ii
# 0  b   0
# 1  c  10
# 2  a  20
# 3  b  30
# 4  c  40
# 5  a  50
# Groupby with sort=True
g = df.groupby("i", sort=True)
#################################################
# 1) This correctly returns a sorted series
ret1 = g.apply(lambda df: df["ii"]+1)
# ret1:
# i
# a  2    21
#    5    51
# b  0     1
#    3    31
# c  1    11
#    4    41
#################################################
# 2) This ignores the sort flag
ret2 = g.apply(lambda df: df[["ii"]]+1)
# ret2:
#    ii
# 0   1
# 1  11
# 2  21
# 3  31
# 4  41
# 5  51
#################################################
# 3) This also ignores the sort flag.
def fun(df):
    # Add a derived column to the group and return the modified group
    df["iii"] = df["ii"] + 1
    return df
ret3 = g.apply(fun)
# ret3
#    i  ii  iii
# 0  b   0    1
# 1  c  10   11
# 2  a  20   21
# 3  b  30   31
# 4  c  40   41
# 5  a  50   51
#################################################
# 4) This, however, respects the sort flag again:
ret4 = {}
for key, dfg in g:
    ret4[key] = fun(dfg)
ret4 = pd.concat(ret4, axis=0)
# ret4:
#        i  ii  iii
# a 2    a  20   21
#   5    a  50   51
# b 0    b   0    1
#   3    b  30   31
# c 1    c  10   11
#   4    c  40   41
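Approach 4 assembles the result itself instead of letting apply decide how to combine the pieces. The sketch below packages that pattern into a small helper. `apply_in_group_order` and `add_iii` are hypothetical names chosen for illustration, and `add_iii` works on a copy of each group (unlike `fun` above, which modifies the group it receives); that copy is an assumption of this sketch, not something the question requires.

```python
import numpy as np
import pandas as pd


def apply_in_group_order(grouped, func):
    """Hypothetical helper: call func on each group and stack the results.

    The dict keys (group labels) become the outer index level of the
    concatenated result. They are produced by iterating the groupby
    object, so they come out sorted when the groupby used sort=True.
    """
    return pd.concat({key: func(group) for key, group in grouped}, axis=0)


def add_iii(group):
    # Hypothetical non-mutating variant of `fun`: work on a copy so the
    # group handed out by groupby is left untouched.
    out = group.copy()
    out["iii"] = out["ii"] + 1
    return out


# Same input data as in the question.
cats = list("bcabca")
vals = np.arange(0, 10 * len(cats), 10)
df = pd.DataFrame({"i": cats, "ii": vals})

ret4 = apply_in_group_order(df.groupby("i", sort=True), add_iii)
print(ret4.index.get_level_values(0).tolist())  # ['a', 'a', 'b', 'b', 'c', 'c']
print(ret4["iii"].tolist())                     # [21, 51, 1, 31, 11, 41]
```

Because the result is built from the groupby iteration rather than from apply's own result assembly, the row order here follows the group order, which is sorted because sort=True.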
Is this a design flaw in pandas, or is this behavior intentional? From the documentation of pd.DataFrame.groupby() and pd.GroupBy.apply(), I would expect solutions 2 and 3 to also take the sort flag into account. Why would they not?
(The problem was reproduced with pandas 1.2.4 and 1.4.0)
Update: A workaround for approaches 2 and 3 is to first sort the DataFrame by the grouping key. Source of inspiration: see the link in the comments.
# Approach 2:
df.sort_values("i").groupby("i").apply(lambda df: df[["ii"]]+1)
# Approach 3:
df.sort_values("i").groupby("i").apply(fun)
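The workaround for approach 2 can be checked on the same data. This is a minimal sketch under a few assumptions of my own: it uses a stable sort (`kind="mergesort"`) so rows that share a key keep their original relative order, and it selects the `ii` column before calling apply instead of selecting it inside the lambda. Only the values are compared, because the index that apply builds in this case is exactly the behavior in question and may vary with the pandas version.

```python
import numpy as np
import pandas as pd

# Same input data as in the question.
cats = list("bcabca")
vals = np.arange(0, 10 * len(cats), 10)
df = pd.DataFrame({"i": cats, "ii": vals})

# Sort by the grouping key first. kind="mergesort" is a stable sort, so rows
# sharing a key keep their original relative order (row 0 before row 3, etc.).
df_sorted = df.sort_values("i", kind="mergesort")

# Variant of approach 2: select "ii" before apply so the grouping column is
# not passed to the function.
ret2_sorted = df_sorted.groupby("i", sort=True)[["ii"]].apply(lambda d: d + 1)

# Compare values only; the rows now appear in sorted group order a, b, c.
print(ret2_sorted["ii"].tolist())  # [21, 51, 1, 31, 11, 41]
```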