
Setup

Suppose I have

import numpy as np

bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
k = 3

I need the position of the maximal value for each unique bin in `bins`.

# Bin == 0
#  ↓ ↓           ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#  ↑ ↑           ↑
#  ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 8 and happens at position 0

(vals * (bins == 0)).argmax()

0

# Bin == 1
#      ↓ ↓         ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#      ↑ ↑         ↑
#        ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 4 and happens at position 3

(vals * (bins == 1)).argmax()

3

# Bin == 2
#          ↓ ↓ ↓     ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#          ↑ ↑ ↑     ↑
#                    ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 9 and happens at position 9

(vals * (bins == 2)).argmax()

9

Those expressions are hacky and don't even generalize to negative values.
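For illustration only (a sketch of mine, not part of the original hack): a mask-based variant that also handles negative values replaces the multiplication with np.where and -np.inf, so entries outside the bin can never win the argmax.

np.where(bins == 0, vals, -np.inf).argmax()  # -> 0, even if vals contained negatives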

Question

How do I get all such positions in the most efficient manner using NumPy?

What I've tried

def binargmax(bins, vals, k):
  out = -np.ones(k, np.int64)    # out[b] holds the argmax position for bin b
  trk = np.empty(k, vals.dtype)  # trk[b] holds the running max value for bin b
  trk.fill(np.nanmin(vals) - 1)  # initialize below the global minimum

  for i in range(len(bins)):
    v = vals[i]
    b = bins[i]
    if v > trk[b]:
      trk[b] = v
      out[b] = i

  return out

binargmax(bins, vals, k)

array([0, 3, 9])

LINK TO TESTING AND VALIDATION

piRSquared
  • So, k is always no. of unique bins? – Divakar Aug 24 '18 at 14:51
  • Yes and should be the same as `bins.max() + 1` – piRSquared Aug 24 '18 at 14:51
  • Are the values guaranteed to be unique per bin? Do you want all maxima? – user3483203 Aug 24 '18 at 14:55
  • how would you treat ties/draws? are you interested in a native python solution also? – Chris_Rands Aug 24 '18 at 14:56
  • Not guaranteed, I want the first position. Like `np.array([1, 2, 2]).argmax()` returns `1` @user3483203 – piRSquared Aug 24 '18 at 14:56
  • @Chris_Rands I'm interested in all solutions. However I want to know a good way to do it with Numpy. My suspicion is that I can use some combination of `argsort` and `np.maximum.at` but couldn't figure it out. I can use `numba` with the function I offered as my attempt and get good performance. But what I really want is a Numpy solution using good technique. That said, I'm still curious and would appreciate all intelligent solutions. – piRSquared Aug 24 '18 at 14:59
  • Sure... (-: Sorry I missed it. Done! – piRSquared Aug 24 '18 at 16:38
  • Would you mind testing the speed of my solution, sir? – BENY Aug 24 '18 at 19:48

7 Answers


The numpy_indexed library:

I know this isn't technically numpy, but the numpy_indexed library has a vectorized group_by function that is perfect for this; I just wanted to share it as an alternative I use frequently:

>>> import numpy_indexed as npi
>>> npi.group_by(bins).argmax(vals)
(array([0, 1, 2]), array([0, 3, 9], dtype=int64))

Using a simple pandas groupby and idxmax:

df = pd.DataFrame({'bins': bins, 'vals': vals})
df.groupby('bins').vals.idxmax()

Using a sparse.csr_matrix

This option is very fast on very large inputs.

sparse.csr_matrix(
    (vals, bins, np.arange(vals.shape[0]+1)), (vals.shape[0], k)
).argmax(0)

# matrix([[0, 3, 9]])
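To see why this works, here is a sketch on the question's sample data: the (data, indices, indptr) triple places the single entry vals[i] in row i, column bins[i], so argmax(0) returns, per column (i.e. per bin), the row (i.e. position) of the maximum.

m = sparse.csr_matrix((vals, bins, np.arange(vals.shape[0] + 1)), (vals.shape[0], k))
m.toarray()
# array([[8, 0, 0],
#        [7, 0, 0],
#        [0, 3, 0],
#        [0, 4, 0],
#        [0, 0, 1],
#        [0, 0, 2],
#        [0, 0, 6],
#        [5, 0, 0],
#        [0, 0, 0],
#        [0, 0, 9]])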

Performance

Functions

def chris(bins, vals, k):
    return npi.group_by(bins).argmax(vals)

def chris2(df):
    return df.groupby('bins').vals.idxmax()

def chris3(bins, vals, k):
    return sparse.csr_matrix((vals, bins, np.arange(vals.shape[0] + 1)), (vals.shape[0], k)).argmax(0)

def divakar(bins, vals, k):
    mx = vals.max()+1

    sidx = bins.argsort()
    sb = bins[sidx]
    sm = np.r_[sb[:-1] != sb[1:],True]

    argmax_out = np.argsort(bins*mx + vals)[sm]
    max_out = vals[argmax_out]
    return max_out, argmax_out

def divakar2(bins, vals, k):
    last_idx = np.bincount(bins).cumsum()-1
    scaled_vals = bins*(vals.max()+1) + vals
    argmax_out = np.argsort(scaled_vals)[last_idx]
    max_out = vals[argmax_out]
    return max_out, argmax_out


def user545424(bins, vals, k):
    return np.argmax(vals*(bins == np.arange(bins.max()+1)[:,np.newaxis]),axis=-1)

def user2699(bins, vals, k):
    res = []
    for v in np.unique(bins):
        idx = (bins==v)
        r = np.where(idx)[0][np.argmax(vals[idx])]
        res.append(r)
    return np.array(res)

def sacul(bins, vals, k):
    return np.lexsort((vals, bins))[np.append(np.diff(np.sort(bins)), 1).astype(bool)]

@njit
def piRSquared(bins, vals, k):
    out = -np.ones(k, np.int64)
    trk = np.empty(k, vals.dtype)
    trk.fill(np.nanmin(vals) - 1)

    for i in range(len(bins)):
        v = vals[i]
        b = bins[i]
        if v > trk[b]:
            trk[b] = v
            out[b] = i

    return out

Setup

import numpy_indexed as npi
import numpy as np
import pandas as pd
from timeit import timeit
import matplotlib.pyplot as plt
from numba import njit
from scipy import sparse

res = pd.DataFrame(
       index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
       columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
       dtype=float
)

k = 5

for f in res.index:
    for c in res.columns:
        bins = np.random.randint(0, k, c)
        vals = np.random.rand(c)
        df = pd.DataFrame({'bins': bins, 'vals': vals})
        stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
        setp = 'from __main__ import bins, vals, k, df, {}'.format(f)
        res.at[f, c] = timeit(stmt, setp, number=50)

ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");

plt.show()

Results

(benchmark plot: relative time vs. N, log-log, k = 5)

Results with a much larger k (This is where broadcasting gets hit hard):

res = pd.DataFrame(
       index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
       columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
       dtype=float
)

k = 500

for f in res.index:
    for c in res.columns:
        bins = np.random.randint(0, k, c)
        vals = np.random.rand(c)
        df = pd.DataFrame({'bins': bins, 'vals': vals})
        stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
        setp = 'from __main__ import bins, vals, df, k, {}'.format(f)
        res.at[f, c] = timeit(stmt, setp, number=50)

ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");

plt.show()

(benchmark plot: relative time vs. N, log-log, k = 500)

As the graphs show, broadcasting is a nifty trick when the number of groups is small, but its time and memory cost grow too quickly at higher k values for it to remain performant.

user3483203
  • Mind adding the timing of mine? – BENY Aug 24 '18 at 19:46
  • Nice benchmark! As the author of numpy_indexed, let me note that the library is optimized to be 'numpythonic' and generic. That is, your bins need not be ints starting at 0; they could be any type, and any-dimensional ndarrays in fact. That does add a little overhead here and there, but if performance is your primary goal then indeed there is no arguing with numba for this type of problem. Still, it's good to have a reference implementation with a simple API to test your low-level code against! – Eelco Hoogendoorn Aug 25 '18 at 11:14
  • Very nice use of sparse. You’ve given me two good ideas for my tool box. – piRSquared Aug 25 '18 at 18:10
  • You might want to test using a CSR vs. a CSC sparse matrix here. Because of the type of operation being done, one might be faster. I think the arguments are just about the same. I’ll post the CSC solution when I’m at a computer. – user3483203 Aug 25 '18 at 19:43

Here's one way: offset each group's data so that we can use argsort on the entire array in one go -

def binargmax_scale_sort(bins, vals):
    w = np.bincount(bins)
    valid_mask = w!=0
    last_idx = w[valid_mask].cumsum()-1
    scaled_vals = bins*(vals.max()+1) + vals
    #unique_bins = np.flatnonzero(valid_mask) # if needed
    return len(bins) - 1 - np.argsort(scaled_vals[::-1], kind='mergesort')[last_idx]
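As a quick trace (mine, not part of the answer) on the question's sample data: with vals.max() + 1 == 10, the scaled values sort by bin first and by value within each bin, so the last entry of each bin's run in sorted order is that bin's maximum; reversing before the stable mergesort argsort is what makes ties resolve to the first position.

scaled_vals = bins * (vals.max() + 1) + vals
scaled_vals
# array([ 8,  7, 13, 14, 21, 22, 26,  5, 10, 29])

binargmax_scale_sort(bins, vals)
# array([0, 3, 9])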
Divakar
  • @piRSquared Suggestion for better solutions - Won't it be better to have the unique bins outputted for the cases where the bins don't cover the range `0-bins.max()`? – Divakar Aug 24 '18 at 19:04
  • @piRSquared Yeah I am testing out with `bins, vals = gen_arrays(5000, 10000)` and my modified solution only covers for the unique ones, not the entire range and hence mismatching against `binargmax`. – Divakar Aug 24 '18 at 19:07
  • I was expecting your scaled_vals as I’ve seen you use it before. Using cumsum to derive last_idx in anticipation of slicing the result of argsort!? Brilliant! Though I loathe the sort, I can’t deny the ingenuity. – piRSquared Aug 25 '18 at 18:08
  • @piRSquared Discovering I could use bincount to get the last indices per group was one of the little *eureka* moments I must admit. Interesting Q&A for sure, this one. Getting argmax (first ones) needed some more brain work :) – Divakar Aug 25 '18 at 18:11

Okay, here's my linear-time entry, using only indexing and np.maximum.at / np.minimum.at. It assumes bins go up from 0 to max(bins).

def via_at(bins, vals):
    max_vals = np.full(bins.max()+1, -np.inf)
    np.maximum.at(max_vals, bins, vals)
    expanded = max_vals[bins]
    max_idx = np.full_like(max_vals, np.inf)
    np.minimum.at(max_idx, bins, np.where(vals == expanded, np.arange(len(bins)), np.inf))
    return max_vals, max_idx
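Applied to the question's sample data this returns both the maxima and their positions (as floats, since the accumulators are initialized with ±inf):

>>> via_at(bins, vals)
(array([8., 4., 9.]), array([0., 3., 9.]))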
DSM

How about this:

>>> import numpy as np
>>> bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
>>> vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
>>> k = 3
>>> np.argmax(vals*(bins == np.arange(k)[:,np.newaxis]),axis=-1)
array([0, 3, 9])
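To make the memory cost explicit (my illustration, not part of the answer): the comparison broadcasts to a (k, n) intermediate, which is why large k hurts, as the comments below note.

>>> mask = bins == np.arange(k)[:, np.newaxis]
>>> mask.shape
(3, 10)
>>> vals * mask
array([[8, 7, 0, 0, 0, 0, 0, 5, 0, 0],
       [0, 0, 3, 4, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 2, 6, 0, 0, 9]])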
user545424
  • That is clever (-: The time complexity and memory demand will blow up with large k (I think). – piRSquared Aug 24 '18 at 15:12
  • @piRSquared, I've put in some benchmarks for that. With 30 or so bins it works great, with 1000 performance drops. With only 3 bins it's by far the fastest answer. – user2699 Aug 24 '18 at 15:56
  • I'm doing the same. This should be linear by length of `vals`. My initial approach is the fastest when I apply Numba's `njit`. I'll show it. I wanted an O(n) Numpy approach. This does come close. – piRSquared Aug 24 '18 at 15:58

If you're going for readability, this might not be the best solution, but I think it works:

def binargsort(bins,vals):
    s = np.lexsort((vals,bins))
    s2 = np.sort(bins)
    msk = np.roll(s2,-1) != s2
    # or use this for msk, but not noticeably better for performance:
    # msk = np.append(np.diff(np.sort(bins)),1).astype(bool)
    return s[msk]

array([0, 3, 9])

Explanation:

lexsort returns the indices that sort the data by bins first, and by vals within each bin:

>>> np.lexsort((vals,bins))
array([7, 1, 0, 8, 2, 3, 4, 5, 6, 9])

You can then mask that result at the positions where the sorted bins differ from one index to the next:

>>> np.sort(bins)
array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])

# Find where sorted bins end, use that as your mask on the `lexsort`
>>> np.append(np.diff(np.sort(bins)),1)
array([0, 0, 1, 0, 0, 1, 0, 0, 0, 1])

>>> np.lexsort((vals,bins))[np.append(np.diff(np.sort(bins)),1).astype(bool)]
array([0, 3, 9])
sacuL
  • See the validation section of my link in question. This is returning the last position of the max. – piRSquared Aug 24 '18 at 17:12
  • hmm... my edited solution (using `s2 = np.sort(bins); msk = np.roll(s2,-1) != s2`) passes the first 2 validations but not the third... not sure what's going on, trying to figure that out. – sacuL Aug 24 '18 at 17:26

This is a fun little problem to solve. My approach is to get a boolean index into vals based on the values in bins. Using where to get the positions where the index is True, in combination with argmax on those positions in vals, gives the resulting position.

def binargmaxA(bins, vals):
    res = []
    for v in np.unique(bins):
        idx = (bins == v)
        r = np.where(idx)[0][np.argmax(vals[idx])]
        res.append(r)
    return np.array(res)

It's possible to remove the call to unique by using range(k) to get the possible bin values. This speeds things up, but performance still degrades as k increases.

def binargmaxA2(bins, vals, k):
    res = []
    for v in range(k):
        idx = (bins == v)
        r = np.where(idx)[0][np.argmax(vals[idx])]
        res.append(r)
    return np.array(res)

For a last try: comparing against each value slows things down substantially, so this version works on the sorted array of values rather than making a comparison for each unique value. Well, it actually computes the sorted indices and only gets the sorted values when needed, as that avoids loading vals into memory one extra time. Performance still scales with the number of bins, but far less steeply than before.

def binargmaxB(bins, vals):
    idx = np.argsort(bins)   # indices that sort bins
    split = np.r_[0, np.where(np.diff(bins[idx]))[0] + 1, len(bins)]  # where each bin's run starts in the sorted array
    newmax = [np.argmax(vals[idx[i1:i2]]) for i1, i2 in zip(split, split[1:])]  # max position within each run
    return idx[newmax + split[:-1]]  # convert back to indices in the unsorted array
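A quick trace on the question's sample data (my illustration) shows how split marks where each bin's run begins in the sorted order:

>>> idx = np.argsort(bins)
>>> bins[idx]
array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
>>> np.r_[0, np.where(np.diff(bins[idx]))[0] + 1, len(bins)]
array([ 0,  3,  6, 10])
>>> binargmaxB(bins, vals)
array([0, 3, 9])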

Benchmarks

Here's some benchmarks with the other answers.

3000 elements

With a somewhat larger dataset (bins = randint(0, 30, 3000); vals = randn(3000); k=30;)

  • 171us binargmax_scale_sort2 by Divakar
  • 209us this answer, version B
  • 281us binargmax_scale_sort by Divakar
  • 329us broadcast version by user545424
  • 399us this answer, version A
  • 416us answer by sacul, using lexsort
  • 899us reference code by piRSquared

30000 elements

And an even larger dataset (bins = randint(0, 30, 30000); vals = randn(30000); k=30). Surprisingly this doesn't change the relative performance between solutions.

  • 1.27ms this answer, version B
  • 2.01ms binargmax_scale_sort2 by Divakar
  • 2.38ms broadcast version by user545424
  • 2.68ms this answer, version A
  • 5.71ms answer by sacul, using lexsort
  • 9.12ms reference code by piRSquared

Edit: I didn't change k with the increasing number of possible bin values; now that I've fixed that, the benchmarks are more even.

1000 bin values

Increasing the number of unique bin values may also have an impact on performance. The solutions by Divakar and sacul are mostly unaffected, while the others take quite a substantial hit. bins = randint(0, 1000, 30000); vals = randn(30000); k = 1000

  • 1.99ms binargmax_scale_sort2 by Divakar
  • 3.48ms this answer, version B
  • 6.15ms answer by sacul, using lexsort
  • 10.6ms reference code by piRSquared
  • 27.2ms this answer, version A
  • 129ms broadcast version by user545424

Edit: Including benchmarks for the reference code in the question: it's surprisingly competitive, especially with more bins.

user2699

I know you said to use Numpy, but if Pandas is acceptable:

import numpy as np
import pandas as pd
(pd.DataFrame(
    {'bins':np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2]),
     'values':np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])}) 
.groupby('bins')
.idxmax())

      values
bins        
0          0
1          3
2          9
user1717828