I see a lot of SO questions on related topics, but none of them provides an efficient method.

I want to find the k-th smallest element (or the median) in a 2D array [1..M][1..N] where each row is sorted in ascending order and all elements are distinct.
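
As a reference point, here is a deliberately naive baseline. It flattens and sorts the matrix in O(MN log MN), so it is not an answer to the efficiency question. It is only useful as a correctness oracle when testing faster methods. The function name `kth_smallest_bruteforce` and the 1-based `k` convention are assumptions made for this sketch.

```python
from itertools import chain


def kth_smallest_bruteforce(matrix, k):
    """Return the k-th smallest element (1-based k) of a row-sorted matrix.

    Naive O(MN log MN) reference: flatten every row and sort.
    Only meant as a correctness oracle for faster methods.
    """
    values = sorted(chain.from_iterable(matrix))
    if not 1 <= k <= len(values):
        raise ValueError("k out of range")
    return values[k - 1]


matrix = [
    [1, 5, 9],
    [2, 3, 20],
    [4, 6, 7],
]
print(kth_smallest_bruteforce(matrix, 5))  # 5 (the median of the 9 values)
```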

I think there is an O(M log MN) solution, but I have no idea how to implement it. (Median of medians or partition-based selection with linear complexity are possible methods, but I have no further ideas...)

This is an old Google interview question and can be found here.

Now I would like a hint or a description of the most efficient (fastest) algorithm.

I also read a paper here, but I don't understand it.

Update 1: I found one solution here, but it only works when the dimension is odd.

Kedar Mhaswade
  • You will probably get more insightful answers on [the computer science stackexchange](https://cs.stackexchange.com) – Stef Nov 19 '20 at 18:10
  • I found a very good solution here, and I think this is a much bigger community than CS. @Stef thanks. –  Nov 19 '20 at 20:46
  • Are you asking about sorted rows ONLY, or sorted rows AND columns? Your description and bound look reasonable for sorted rows ONLY. But all of your links are to sorted rows AND columns. – btilly Nov 19 '20 at 21:47
  • @btilly wow, thanks, I have read many of your nice answers on problems like mine. Only the rows are sorted; we know nothing about the columns (because that was not mentioned in the interview question). I added the links because I think this is a more specific case of those, not sure. –  Nov 19 '20 at 22:14
  • @Spektre there is a solution that finds the answer in O(M log MN). I think this is lower than your time complexity, isn't it? Please add your answer here. –  Nov 27 '20 at 12:40
  • @MokholiaPokholia I converted it to an answer... Maybe I misunderstood your problem, or everyone else is overthinking this (less likely) – Spektre Nov 27 '20 at 14:23

5 Answers

So to solve this problem, it helps to solve a slightly different one. We want to know the upper/lower bounds in each row for where the overall k'th cutoff is. Then we can go through and verify that the number of things at or below the lower bounds is < k, the number of things at or below the upper bounds is > k, and there is only one value between them.
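
The verification step needs only one binary search per row. Below is a minimal sketch, assuming distinct values and a 1-based `k`. The helper names `count_at_most`, `contains` and `is_kth_smallest` are hypothetical and do not appear in the answer's code. With distinct values, `x` is the k-th smallest exactly when it occurs in the matrix and exactly `k` elements are `<= x`. Each check costs O(M log N).

```python
from bisect import bisect_left, bisect_right


def count_at_most(matrix, x):
    """Number of elements <= x, using one binary search per sorted row: O(M log N)."""
    return sum(bisect_right(row, x) for row in matrix)


def contains(matrix, x):
    """True if x occurs in some row (binary search per row)."""
    for row in matrix:
        i = bisect_left(row, x)
        if i < len(row) and row[i] == x:
            return True
    return False


def is_kth_smallest(matrix, k, x):
    """With distinct values, x is the k-th smallest iff x is present and exactly k elements are <= x."""
    return contains(matrix, x) and count_at_most(matrix, x) == k


matrix = [[1, 5, 9], [2, 3, 20], [4, 6, 7]]
print(count_at_most(matrix, 5))       # 5
print(is_kth_smallest(matrix, 5, 5))  # True
print(is_kth_smallest(matrix, 5, 6))  # False
```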

I've come up with a strategy for doing a binary search in all rows simultaneously for those bounds. Being a binary search, it "should" take O(log(n)) passes. Each pass involves O(m) work, for a total of O(m log(n)) time. I put "should" in quotes because I don't have a proof that it actually takes O(log(n)) passes. In fact, it is possible to be too aggressive in a row, discover from other rows that the chosen pivot was off, and then have to back off. But I believe that it does very little backing off and is actually O(m log(n)).

The strategy is to keep track, in each row, of a lower bound, an upper bound, and a mid. Each pass, we build a weighted series of ranges: start to lower, lower to mid, mid to upper, and upper to the end, with the weight being the number of elements in the range and the value being its last element. We then find the k'th value (by weight) in that data structure and use it as the pivot for our binary search in each row.

If a pivot winds up out of the range from lower to upper, we correct by widening the interval in the direction that corrects the error.

When we have the correct sequence, we've got an answer.

There are a lot of edge cases, so staring at full code may help.

I also assume that all elements of each row are distinct. If they are not, you can get into endless loops. (Solving that means even more edge cases...)

import random

# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth (k, filtered_pairs)
    else:
        return pivot

# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is.  Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end.  Some ranges may be
    # empty.  We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the list.  We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range.  (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all the lowers
    # are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper: # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any adjusts we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First, cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either:
            # - double the size of the range, or
            # - go halfway to the edge.
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid]:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # We look like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))
        data = new_data # And set up the next search
    return pivot
btilly
  • All elements are distinct; that is a correct assumption. –  Nov 21 '20 at 05:14
  • @MokholiaPokholia Please tell me if you find any cases where it doesn't work as promised. – btilly Nov 21 '20 at 06:34
  • Very nice, give me some minutes to inspect it. One question that first comes to mind: how can we prove the complexity before getting insight into it? –  Nov 21 '20 at 09:49
  • One minor point I misunderstand: what is your time complexity? –  Nov 22 '20 at 08:51
  • @MokholiaPokholia I don't have a proof. But I believe that the time complexity is `O(m log(n))`. I have another variant that can handle duplicates and has slightly better behavior, but again I don't have a proof of performance. (The difference is that it cuts intervals into thirds and uses the range trick to establish upper/lower bounds on the k'th value. Then it throws away the parts of each row that are definitely not within bounds.) – btilly Nov 22 '20 at 21:49
  • I put up a bounty. Please check the last part of this question, starting from the text "So I have finally googled a solution...": https://www.quora.com/What-is-the-most-efficient-fast-way-to-finding-k-th-smallest-element-or-median-on-2D-array-1-M-1-N-where-each-rows-of-this-matrix-is-sorted-in-ascending-order-and-all-elements-are-distinct –  Nov 23 '20 at 02:29
  • This algorithm bears a strong family resemblance to Mirzaian and Arjomandi (Selection in X + Y and matrices with sorted columns, 1985), the main difference here being the recovery logic, which they avoid by exploiting the ordering on both rows and columns. Since that paper followed the more complicated algorithm of Frederickson and Johnson (The complexity of selection and ranking in X + Y and matrices with sorted columns, 1982) that needed only one dimension to be sorted, I suspect that they tried and failed to do what you're claiming here, which makes me suspicious. – David Eisenstat Nov 23 '20 at 12:22
  • @DavidEisenstat I am suspicious that I don't know how to prove the time complexity. I would be surprised if it is wrong though. – btilly Nov 24 '20 at 18:26
Another answer has been added to provide an actual solution. This one has been left as it was due to quite the rabbit hole in the comments.


I believe the fastest solution for this is the k-way merge algorithm. It is an O(N log K) algorithm to merge K sorted lists with a total of N items into a single sorted list of size N.

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

Given an MxN matrix, this ends up being O(MN log(M)). However, that is for sorting the entire list. Since you only need the K smallest items instead of all N*M, the performance is O(K log(M)). This is quite a bit better than what you are looking for, assuming O(K) <= O(M).

This assumes you have N sorted lists of size M. If you actually have M sorted lists of size N, this is easily handled just by changing how you loop over the data (see the pseudocode below), though it does mean the performance is O(K log(N)) instead.

A k-way merge just adds the first item of each list to a heap or other data structure with an O(log N) insert and O(log N) find-min.

Pseudocode for k-way merge looks a bit like this:

  1. For each sorted list, insert the first value into the data structure with some means of determining which list the value came from, i.e. you might insert [value, row_index, col_index] into the data structure instead of just the value. This also lets you easily handle looping over either columns or rows.
  2. Remove the lowest value from the data structure and append it to the sorted list.
  3. Given that the item in step #2 came from list I, add the next lowest value from list I to the data structure. E.g.: if the value was row 5 col 4 (data[5][4]), then if you are using rows as lists, the next value would be row 5 col 5 (data[5][5]); if you are using columns, the next value is row 6 col 4 (data[6][4]). Insert this next value into the data structure as you did in #1 (i.e. [value, row_index, col_index]).
  4. Go back to step 2 as needed.

For your needs, do steps 2-4 K times.
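
Steps 1 to 4 map almost directly onto Python's `heapq`. This is a minimal sketch, assuming the matrix is a list of sorted rows (so the lists being merged are the M rows) and a 1-based `k`; `kth_smallest_kway` is a hypothetical name. The heap never holds more than M entries, so the k-1 pop/push rounds cost O(k log M) on top of an O(M) heapify.

```python
import heapq


def kth_smallest_kway(matrix, k):
    """k-th smallest (1-based) via a k-way merge of the sorted rows.

    The heap holds (value, row_index, col_index) for the next unused
    element of each row, so its size never exceeds M.
    """
    total = sum(len(row) for row in matrix)
    if not 1 <= k <= total:
        raise ValueError("k out of range")
    # Step 1: seed the heap with the first element of every non-empty row.
    heap = [(row[0], r, 0) for r, row in enumerate(matrix) if row]
    heapq.heapify(heap)  # O(M)
    for _ in range(k - 1):
        # Step 2: remove the current smallest value.
        _, r, c = heapq.heappop(heap)
        # Step 3: replace it with the next value from the same row.
        if c + 1 < len(matrix[r]):
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    # After k-1 removals, the heap minimum is the k-th smallest.
    return heap[0][0]


matrix = [[1, 5, 9], [2, 3, 20], [4, 6, 7]]
print([kth_smallest_kway(matrix, k) for k in range(1, 10)])
# [1, 2, 3, 4, 5, 6, 7, 9, 20]
```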

Nuclearman
  • Comments are not for extended discussion; this conversation has been [moved to chat](https://chat.stackoverflow.com/rooms/225232/discussion-on-answer-by-nuclearman-fastest-algorithm-for-kth-smallest-element-o). – Bhargav Rao Nov 28 '20 at 03:12
It seems like the best way to go is a k-way merge in increasingly larger blocks. A k-way merge seeks to build a sorted list, but we don't need it sorted and we don't need to consider each element. Instead, we'll create semi-sorted intervals. The intervals will be sorted, but only on the highest value.

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

We use the same approach as a k-way merge, but with a twist. Basically, it aims to indirectly build a semi-sorted sublist. For example, instead of finding [1,2,3,4,5,6,7,8,9,10] to determine the value for K=10, it will instead find something like [(1,3),(4,6),(7,15)]. With a k-way merge we consider 1 item at a time from each list. In this approach, however, when pulling from a given list, we want to first consider Z items, then 2 * Z items, then 2 * 2 * Z items, so 2^i * Z items for the i-th time. Given an MxN matrix, that means we pull from each of the M lists up to O(log(N)) times.

  1. For each sorted list, insert the first K sublists into the data structure with some means of determining which list the value came from. We want the data structure to use the highest value in the sublist we insert into it. In this case we would want something like [max_value of sublist, row index, start_index, end_index]. O(m)
  2. Remove the lowest value (this is now a list of values) from the data structure and append to the sorted list. O(log (m))
  3. Given that the item in step #2 came from list I add the next 2^i * Z values from list I to the data structure upon the i-th time pulling from that specific list (basically just double the number that was present in the sublist just removed from the data structure). O(log m)
  4. If the size of the semi-sorted sublist is greater than K, use binary search to find the kth value, O(log N). If there are any sublists remaining in the data structure where the min value is less than k, go to step 1 with the lists as inputs and the new K being k - (size of semi-sorted list).
  5. If the size of the semi-sorted sublist is equal to K, return the last value in the semi-sorted sublist, this is the Kth value.
  6. If the size of the semi-sorted sublist is less than K, go back to step 2.

As for performance. Let's see here:

  • Takes O(m log m) to add the initial values to the data structure.
  • It needs to consider at most O(m) sublists, each requiring O(log n) time, for O(m log n).
  • It needs to perform a binary search at the end, O(log m). It may need to reduce the problem into recursive sublists if there is uncertainty about what the value of K is (step 4), but I don't think that'll affect the big O. Edit: I believe this just adds another O(m log(n)) in the worst case, which has no effect on the big O.

So looks like it's O(mlog(m) + mlog(n)) or simply O(mlog(mn)).

As an optimization, if K is above NM/2, consider the max value where you would consider the min value, and the min value where you would consider the max value. This will greatly improve performance when K is close to NM.
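
The symmetric trick can be sketched by merging from the right ends of the rows with a max-heap: the k-th smallest of T = M*N values is the (T - k + 1)-th largest. Python's `heapq` is a min-heap, so values are negated, which assumes numeric entries. This is a minimal sketch applied to the plain one-item-at-a-time k-way merge rather than to the block-based variant; it also assumes a list of sorted rows and a 1-based `k`, and `kth_smallest_either_end` is a hypothetical name.

```python
import heapq


def kth_smallest_either_end(matrix, k):
    """k-th smallest (1-based) of a row-sorted matrix, merging from whichever end is closer."""
    total = sum(len(row) for row in matrix)
    if not 1 <= k <= total:
        raise ValueError("k out of range")
    if k <= total // 2:
        # Small k: pop minima starting from the row starts (min-heap).
        heap = [(row[0], r, 0) for r, row in enumerate(matrix) if row]
        step, sign, steps = 1, 1, k - 1
    else:
        # Large k: pop maxima starting from the row ends; negate to get a max-heap.
        heap = [(-row[-1], r, len(row) - 1) for r, row in enumerate(matrix) if row]
        step, sign, steps = -1, -1, total - k
    heapq.heapify(heap)
    for _ in range(steps):
        _, r, c = heapq.heappop(heap)
        nxt = c + step  # next position in row r, moving toward its middle
        if 0 <= nxt < len(matrix[r]):
            heapq.heappush(heap, (sign * matrix[r][nxt], r, nxt))
    return sign * heap[0][0]


matrix = [[1, 5, 9], [2, 3, 20], [4, 6, 7]]
print(kth_smallest_either_end(matrix, 8))  # 9
```

For k = 8 of 9 values only one maximum (20) is popped, instead of seven minima.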

Nuclearman
The answers by btilly and Nuclearman provide two different approaches, a kind of binary search and a k-way merge of the rows.

My proposal is to combine both methods.

  • If k is small (let's say less than M times 2 or 3) or big enough (for symmetry, close to N x M), find the kth element with an M-way merge of the rows. Of course, we shouldn't merge all the elements, just the first k.

  • Otherwise, start by inspecting the first and the last column of the matrix in order to find the minimum (which is in the first column) and the maximum (in the last column) values.

  • Estimate a first pivotal value as a linear combination of those two values. Something like pivot = min + k * (max - min) / (N * M).

  • Perform a binary search in each row to determine the last element (the closest one) not greater than the pivot. The number of elements less than or equal to the pivot then follows directly. Comparing the sum of those counts with k tells us whether the chosen pivot value is too big or too small and lets us modify it accordingly. Keep track of the maximum among those elements over all the rows; it may be the kth element, or it can be used to evaluate the next pivot. If we consider said sum as a function of the pivot, the numeric problem is now to find the zero of sum(pivot) - k, which is a monotonic (discrete) function. At worst, we can use the bisection method (logarithmic complexity) or the secant method.

  • We can ideally partition each row in three ranges:

    • At the left, the elements which are surely less than or equal to the kth element.
    • In the middle, the undetermined range.
    • At the right, the elements which are surely greater than the kth element.
  • The undetermined range will shrink at every iteration, eventually becoming empty for most rows. At some point, the number of elements still in the undetermined ranges, scattered throughout the matrix, will be small enough to resort to a single M-way merge of those ranges.

  • If we consider the time complexity of a single iteration as O(M log N), i.e. M binary searches, we need to multiply it by the number of iterations required for the pivot to converge to the value of the kth element, which could be O(log NM). This sums up to O(M log N log M), or O(M log N log N) if N > M.

  • Note that, if the algorithm is used to find the median, with the M-way merge as the last step it is easy to find the (k + 1)th element too.
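
The pivot search in these bullets can be sketched in its simplest form: plain bisection on the value range, counting elements `<= pivot` with one binary search per row. This sketch assumes integer entries (so the value range can be bisected exactly) and a 1-based `k`. It omits the interpolated first pivot, the three-range bookkeeping and the final M-way merge, so its cost is O(M log N log(max - min)) rather than anything tighter. `kth_by_value_bisection` is a hypothetical name.

```python
from bisect import bisect_right


def kth_by_value_bisection(matrix, k):
    """k-th smallest (1-based) of a row-sorted matrix of integers.

    Bisection on the value range: find the smallest value v such that
    at least k elements are <= v. With integer entries that v is an
    element of the matrix, namely the k-th smallest.
    """
    rows = [row for row in matrix if row]
    total = sum(len(row) for row in rows)
    if not 1 <= k <= total:
        raise ValueError("k out of range")
    lo = min(row[0] for row in rows)   # the minimum lies in the first column
    hi = max(row[-1] for row in rows)  # the maximum lies in the last column

    def count_at_most(pivot):
        # One binary search per row: O(M log N).
        return sum(bisect_right(row, pivot) for row in rows)

    # Invariant: count_at_most(hi) >= k and the answer lies in [lo, hi].
    while lo < hi:
        mid = (lo + hi) // 2
        if count_at_most(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


matrix = [[1, 5, 9], [2, 3, 20], [4, 6, 7]]
print(kth_by_value_bisection(matrix, 5))  # 5
```

Replacing the midpoint with the interpolated estimate pivot = min + k * (max - min) / (N * M) (safeguarded to stay inside [lo, hi]) would be the first refinement suggested above.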

Bob__
  • Interesting algorithm. I was considering doing something similar but wasn't sure it would work correctly (or be more performant), so just stuck to k-way merge. I believe the partition bit was what I was missing to make it possible, so kudos for working that out. Seems like a solid approach, but not 100% sure it is correct as-is but seems close enough to be used. – Nuclearman Dec 26 '20 at 21:00
Maybe I am missing something, but if your NxM matrix A has M rows that are already sorted ascending with no repeated elements, then the k-th smallest value of a row is just the k-th element of that row, which is O(1). To move to 2D, you just select the k-th column instead, sort it ascending in O(M.log(M)) and again pick the k-th element, leading to O(N.log(N)).

  1. let's have a matrix A[N][M]

    where elements are A[column][row]

  2. sort k-th column of A ascending O(M.log(M))

    so sort A[k][i] where i = { 1,2,3,...M } ascending

  3. pick A[k][k] as the result

If you instead want the k-th smallest of all the elements in A, then you need to exploit the already sorted rows in a form similar to merge sort.

  1. create empty list c[] for holding k smallest values

  2. process columns

  3. create a temp array b[]

    which holds the processed column, quick-sorted ascending in O(N.log(N))

  4. merge c[] and b[] so that c[] holds up to k smallest values

    Using a temp array d[], this costs O(k+n)

  5. if no item from b[] was used during the merge, stop processing columns

    This can be done by adding a flag array f[] that records whether each value was taken from b[] or c[] during the merge, and then just checking whether any value was taken from b[]

  6. output c[k-1]
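
The column-by-column steps above can be sketched in Python. This is a minimal sketch, not the author's C++ code: it assumes a list of equally long sorted rows with distinct values and a 1-based `k`, and it uses `heapq.merge` plus `itertools.islice` in place of the hand-written merge, with source tags playing the role of the flag array f[]. `kmin_by_columns` is a hypothetical name. Stopping is safe because every later value in a row is larger than that row's value in the current column.

```python
import heapq
from itertools import islice


def kmin_by_columns(matrix, k):
    """k-th smallest (1-based) of a row-sorted matrix, one column at a time.

    c holds the (at most k) smallest values seen so far, sorted ascending.
    Each column is sorted and merged into c; if no value from the column
    survives in c, no later column can contribute.
    """
    rows = len(matrix)
    cols = len(matrix[0]) if rows else 0
    if not 1 <= k <= rows * cols:
        raise ValueError("k out of range")
    c = []
    for j in range(cols):
        b = sorted(row[j] for row in matrix)  # the column's values, one per row
        # Tag each value with its source: 0 = kept from c, 1 = from column b.
        merged = list(islice(heapq.merge(((v, 0) for v in c), ((v, 1) for v in b)), k))
        c = [v for v, _ in merged]
        if not any(src for _, src in merged):
            break  # nothing from this column was used: stop processing columns
    return c[k - 1]


matrix = [[1, 5, 9], [2, 3, 20], [4, 6, 7]]
print(kmin_by_columns(matrix, 5))  # 5
```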

Putting it all together, the final complexity is O(min(k,M).N.log(N)): if k is less than M we can rewrite it as O(k.N.log(N)), otherwise O(M.N.log(N)). Also, on average the number of columns to iterate will be even smaller, more likely ~(1+(k/N)), so the average complexity would be ~O(N.log(N)), but that is just my wild guess, which might be wrong.

Here is a small C++/VCL example:

//$$---- Form CPP ----
//---------------------------------------------------------------------------
#include <vcl.h>
#pragma hdrstop
#include "Unit1.h"
#include "sorts.h"
//---------------------------------------------------------------------------
#pragma package(smart_init)
#pragma resource "*.dfm"
TForm1 *Form1;
//---------------------------------------------------------------------------
const int m=10,n=8; int a[m][n],a0[m][n]; // a[col][row]
//---------------------------------------------------------------------------
void generate()
    {
    int i,j,k,ii,jj,d=13,b[m];
    Randomize();
    RandSeed=0x12345678;
    // a,a0 = some distinct pseudorandom values (fully ordered asc)
    for (k=Random(d),j=0;j<n;j++)
     for (i=0;i<m;i++,k+=Random(d)+1)
      { a0[i][j]=k; a[i][j]=k; }
    // shuffle a
    for (j=0;j<n;j++)
     for (i=0;i<m;i++)
        {
        ii=Random(m);
        jj=Random(n);
        k=a[i][j]; a[i][j]=a[ii][jj]; a[ii][jj]=k;
        }
    // sort rows asc
    for (j=0;j<n;j++)
        {
        for (i=0;i<m;i++) b[i]=a[i][j];
        sort_asc_quick(b,m);
        for (i=0;i<m;i++) a[i][j]=b[i];
        }

    }
//---------------------------------------------------------------------------
int kmin(int k) // k-th min (1-based) from a[m][n] where a rows are already sorted
    {
    int i,j,bi,ci,di,b[n],*buf,*c,*d,*e,*f,cn;
    // handle edge cases (before allocating, so nothing leaks)
    if ((m<1)||(n<1)) return -1;
    if ((k<1)||(k>m*n)) return -1;
    buf=new int[k+k+k]; c=buf; d=c+k; f=d+k;   // c[],d[] = k values each, f[] = k source flags
    // process columns
    for (cn=0,i=0;i<m;i++)
        {
        // b[] = sorted_asc a[i][]
        for (j=0;j<n;j++) b[j]=a[i][j];     // O(n)
        sort_asc_quick(b,n);                // O(n.log(n))
        // d[] = c[] + b[] asc sorted, stopping after k values; f[di]=1 if d[di] came from b[]
        for (bi=0,ci=0,di=0;di<k;di++)      // O(k)
            {
                 if ((ci>=cn)&&(bi>=n)) break;
            else if (ci>=cn)     { d[di]=b[bi]; f[di]=1; bi++; }
            else if (bi>= n)     { d[di]=c[ci]; f[di]=0; ci++; }
            else if (b[bi]<c[ci]){ d[di]=b[bi]; f[di]=1; bi++; }
            else                 { d[di]=c[ci]; f[di]=0; ci++; }
            }
        e=c; c=d; d=e; cn=di;               // swap buffers, c[] now holds the merged values
        for (ci=0,j=0;j<cn;j++) ci|=f[j];   // O(k)
        if (!ci) break;                     // no value from this column was used: later columns can't help
        }
    k=c[k-1];
    delete[] buf;                           // c may point into the middle of buf after the swaps
    return k;
    }
//---------------------------------------------------------------------------
__fastcall TForm1::TForm1(TComponent* Owner):TForm(Owner)
    {
    int i,j,k;
    AnsiString txt="";

    generate();

    txt+="a0[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a0[i][j]);

    txt+="\r\na[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a[i][j]);

    k=20;
    txt+=AnsiString().sprintf("\r\n%ith smallest from a0 = %4i\r\n",k,a0[(k-1)%m][(k-1)/m]);
    txt+=AnsiString().sprintf("\r\n%ith smallest from a  = %4i\r\n",k,kmin(k));

    mm_log->Lines->Add(txt);
    }
//-------------------------------------------------------------------------

Just ignore the VCL stuff. The function generate computes the a0 and a matrices, where a0 is fully sorted and a has only its rows sorted; all values are distinct. The function kmin is the algorithm described above, returning the k-th smallest value from a[m][n]. For sorting I used this:

template <class T> void sort_asc_quick(T *a,int n)
    {
    int i,j; T a0,a1,p;
    if (n<=1) return;                                   // stop recursion
    if (n==2)                                           // edge case
        {
        a0=a[0];
        a1=a[1];
        if (a0>a1) { a[0]=a1; a[1]=a0; }                // condition
        return;
        }
    for (a0=a1=a[0],i=0;i<n;i++)                        // pivot = middle of the value range (a median would be better)
        {
        p=a[i];
        if (a0>p) a0=p;
        if (a1<p) a1=p;
        } if (a0==a1) return; p=(a0+a1+1)/2;            // if the same values stop
    if (a0==p) p++;
    for (i=0,j=n-1;i<=j;)                               // regroup
        {
        a0=a[i];
        if (a0<p) i++; else { a[i]=a[j]; a[j]=a0; j--; }// condition
        }
    sort_asc_quick(a  ,  i);                            // recursion a[]< p
    sort_asc_quick(a+i,n-i);                            // recursion a[]>=p
    }

And here is the output:

a0[][]
  10   17   29   42   54   66   74   85   90  102 
 112  114  123  129  142  145  146  150  157  161 
 166  176  184  191  195  205  213  216  222  224 
 226  237  245  252  264  273  285  290  291  296 
 309  317  327  334  336  349  361  370  381  390 
 397  398  401  411  422  426  435  446  452  462 
 466  477  484  496  505  515  522  524  525  530 
 542  545  548  553  555  560  563  576  588  590 

a[][]
 114  142  176  264  285  317  327  422  435  466 
 166  336  349  381  452  477  515  530  542  553 
 157  184  252  273  291  334  446  524  545  563 
  17  145  150  237  245  290  370  397  484  576 
  42  129  195  205  216  309  398  411  505  560 
  10  102  123  213  222  224  226  390  496  555 
  29   74   85  146  191  361  426  462  525  590 
  54   66   90  112  161  296  401  522  548  588 

20th smallest from a0 =  161

20th smallest from a  =  161

This example iterated only 5 columns...

Spektre
  • Very nice. How can O(M log MN) be achieved with this method? –  Nov 28 '20 at 02:02
  • @MounaMokhiab I edited my answer... and added an example I just bustled together. Similarly to you, I was thinking that sorting the partially sorted `a` would lead to `O(M.log(M.N))`, but it looks like I was wrong, as it leads to `O(M.N.log(N))` instead. However, I did some tweaking (as we do not need to sort the whole matrix, just the first k smallest elements), hence the complexity difference... – Spektre Nov 28 '20 at 10:48
  • Sure: an M*N matrix means M rows and N columns, where the M rows are sorted and there are no repeated elements. –  Nov 29 '20 at 18:30
  • You can see this definition in the OP. –  Nov 29 '20 at 18:30