Using Numba just-in-time compilation
from numba import njit
@njit
def idxmax_(bins, k, weights):
    """Return, for every group, the row position of its largest weight.

    Parameters
    ----------
    bins : 1-D integer array
        Group code of each row. Codes are used directly as indices into a
        length-``k`` array. Negative codes (the sentinel ``pd.factorize``
        emits for missing keys) are skipped instead of wrapping around to
        the last group.
    k : int
        Number of groups.
    weights : 1-D numeric array
        Values to maximise, aligned row-for-row with ``bins``.

    Returns
    -------
    1-D int64 array
        Row positions of the per-group maxima, sorted ascending. Ties keep
        the first occurrence. Groups that never receive a usable row (only
        NaN weights) are left out rather than reported as position 0.
    """
    # -1 marks "no row recorded yet", so raw weights can be compared directly.
    # This removes the need to shift every weight above zero, which failed on
    # empty input and lost precision for large-magnitude values.
    best = np.full(k, -1, np.int64)
    for i in range(len(weights)):
        b = bins[i]
        w = weights[i]
        # Skip missing group keys and NaN weights (NaN is the only value
        # that is not equal to itself; always False for integer weights).
        if b < 0 or w != w:
            continue
        j = best[b]
        # Strict '>' keeps the earliest row when weights tie.
        if j < 0 or w > weights[j]:
            best[b] = i
    return np.sort(best[best >= 0])
def idxmax(df):
    """Find, per ``gr`` group, where ``col`` is largest.

    ``df`` must expose a ``gr`` column (group key) and a ``col`` column
    (value to maximise). The group keys are integer-encoded with
    ``pd.factorize`` and the search itself is delegated to ``idxmax_``,
    whose result (a sorted array of row positions) is returned unchanged.
    """
    codes, uniques = pd.factorize(df.gr)
    n_groups = len(uniques)
    values = df.col.values
    return idxmax_(codes, n_groups, values)
idxmax(df)
array([ 156, 220, 258, ..., 499945, 499967, 499982])
Make sure to prime the function first by calling it once on a small input, so that Numba compiles it before you time it
idxmax(df.head())
Then time it
# Numba-backed helper defined above
%timeit idxmax(df)
# Pandas-only baseline: sort descending by group and value, then keep the
# first row of each group
%timeit df.sort_values(['gr', 'col'], ascending=False).drop_duplicates('gr').index
6.07 ms ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
152 ms ± 498 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Comparing the results for equality
# Reference result: pandas' own per-group idxmax, sorted so it can be
# compared element-wise with the other two results.
idx0 = df.groupby('gr').col.idxmax().sort_values().values
# Result from the Numba-backed helper.
# NOTE(review): idxmax(df) yields row positions while the pandas results are
# index labels; the comparison presumably relies on df having a default
# integer index -- confirm.
idx1 = idxmax(df)
# Sort-based alternative: order by group and value descending, keep the
# first row of each group, then sort the surviving index.
idx2 = df.sort_values(
['gr', 'col'],
ascending=False
).drop_duplicates('gr').index.sort_values().values
# Prints one line per comparison; True means the arrays match everywhere.
print((idx0 == idx1).all(), (idx0 == idx2).all(), sep='\n')
True
True