Suppose I have a 2D NumPy array. Given n, I want to nullify (set to 0) all elements in the matrix except the n largest.
I've tried idx = (-y_pred).argsort(axis=-1)[:, :n]
to find the indices of the n largest values, but idx
has shape [H,W,n], and I don't understand why.
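A minimal sketch of what that argsort call does, using made-up example data: `argsort(axis=-1)` sorts each row independently along the last axis, so `[:, :n]` selects the positions of the n largest values within every row, not across the whole matrix. Because argsort and slicing both preserve the number of dimensions, `idx.ndim` always equals `y_pred.ndim`; a three-dimensional result therefore suggests `y_pred` itself has more than two axes (worth checking with `y_pred.shape`).

```python
import numpy as np

# Made-up 2D example data (hypothetical values, for illustration only).
y_pred = np.array([[0.1, 0.9, 0.4],
                   [0.7, 0.2, 0.8]])
n = 2

# argsort along the last axis ranks each row on its own, so this gives the
# column indices of the n largest values *per row*.
idx = (-y_pred).argsort(axis=-1)[:, :n]
print(idx.shape)  # (2, 2): n column indices for each of the 2 rows
print(idx)        # [[1 2]
                  #  [2 0]]

# The result keeps the input's number of dimensions.
print(idx.ndim == y_pred.ndim)  # True
```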
I've also tried:
sorted_list = sorted(y_pred, key=lambda x: x[0], reverse=True)
top_ten = sorted_list[:10]
but it didn't return the indices of the top 10 values.
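A short illustration, again with made-up data, of why the `sorted` attempt behaves this way: iterating over a 2D NumPy array yields its rows, so `sorted(..., key=lambda x: x[0])` reorders whole rows by their first element and returns the rows themselves rather than any indices. Ranking individual elements requires working on the flattened array; `np.argsort(..., axis=None)` sorts the flattened array, and `np.unravel_index` maps the flat positions back to (row, col) pairs.

```python
import numpy as np

# Made-up 2D example data (hypothetical values).
y_pred = np.array([[0.1, 0.9, 0.4],
                   [0.7, 0.2, 0.8]])

# Iterating a 2D array yields rows, so sorted() reorders whole rows by
# their first element and returns the rows, not element indices.
sorted_list = sorted(y_pred, key=lambda x: x[0], reverse=True)
print([row.tolist() for row in sorted_list])  # [[0.7, 0.2, 0.8], [0.1, 0.9, 0.4]]

# Ranking single elements: axis=None sorts the flattened array (ascending),
# [::-1] makes it descending, and unravel_index recovers (row, col) pairs.
flat_order = np.argsort(y_pred, axis=None)[::-1]
rows, cols = np.unravel_index(flat_order[:3], y_pred.shape)
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (1, 2), (1, 0)]
```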
Is there an efficient way to find the indices of the top n values and set the rest to zero?
EDIT: The input is an NxM matrix of values, and the output should be a matrix of the same NxM size in which all values are 0 except at the indices that correspond to the top 10 values.
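One possible approach, as a minimal sketch: it assumes "top n" means the n largest values over the whole matrix (not per row) and that ties at the cutoff may be broken arbitrarily. `keep_top_n` is a hypothetical helper name. It uses `np.argpartition`, which moves the n largest entries of the flattened array into its last n positions without fully sorting, then copies only those entries into an all-zero array of the original shape.

```python
import numpy as np


def keep_top_n(y, n):
    """Hypothetical helper: copy of y with everything but the n largest set to 0.

    Ranks over the whole matrix, not per row. Ties at the cutoff are broken
    arbitrarily by np.argpartition, so exactly min(n, y.size) entries are kept.
    """
    flat = y.ravel()
    out = np.zeros_like(flat)
    if n <= 0:
        return out.reshape(y.shape)
    n = min(n, flat.size)
    # After partitioning around kth=-n, the last n positions hold the indices
    # of the n largest values (in no particular order).
    top = np.argpartition(flat, -n)[-n:]
    out[top] = flat[top]
    return out.reshape(y.shape)


# Usage with made-up data: keep the 2 largest values of a 2x3 matrix.
y_pred = np.array([[0.1, 0.9, 0.4],
                   [0.7, 0.2, 0.8]])
result = keep_top_n(y_pred, 2)
print(result)  # 0.9 and 0.8 kept in place, everything else 0
```

For an NxM input and n = 10, `keep_top_n(y_pred, 10)` returns an NxM array. Using `argpartition` instead of a full `argsort` avoids sorting all N*M elements when only the top n are needed.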