
I have implemented the quickselect algorithm.

My problem is that the algorithm ends up in an endless loop when the array contains duplicates.

Can you help me get it to work?

Is the expected complexity O(n), with a worst case of O(n^2)?

#include <iostream> 
#include <vector> 
#include <algorithm> 
#include <ctime> 
using namespace std; 

int rand_partition(vector<int> &a, int left, int right) { 
    int pivotIndex = left + (rand() % (right - left)); 
    //int m = left + (right - left) / 2; //... to test the algo...no rand at this point 
    int pivot = a[pivotIndex]; 
    int i = left; 
    int j = right; 

    do { 
        while (a[i] < pivot) i++; // find left element > pivot 
        while (a[j] > pivot) j--; // find right element < pivot 

        // if i and j not already overlapped, we can swap 
        if (i < j) { 
            swap(a[i], a[j]); 
        } 
    } while (i < j); 

    return i; 
} 

// Returns the n-th smallest element of list within left..right inclusive (i.e. n is zero-based). 
int quick_select(vector<int> &a, int left, int right, int n) { 
    if (left == right) {        // If the list contains only one element 
        return a[left];  // Return that element 
    } 

    int pivotIndex = rand_partition(a, left, right); 

    // The pivot is in its final sorted position 
    if (n == pivotIndex) { 
        return a[n]; 
    } 
    else if (n < pivotIndex) { 
        return quick_select(a, left, pivotIndex - 1, n); 
    } 
    else { 
        return quick_select(a, pivotIndex + 1, right, n); 
    } 
} 

int main() { 

    vector<int> vec= {1, 0, 3, 5, 0, 8, 6, 0, 9, 0}; 

    cout << quick_select(vec, 0, vec.size() - 1, 5) << endl; 

    return 0; 
}
dreamlax
Gerald
  • To test your algorithm, sort all of these one at a time: {}, {0}, {0,1}, {1,0}, {0,0}, {0,1,2}, {0,2,1}, {1,0,2}, {1,2,0}, {2,0,1}, {2,1,0}, {0,0,0}, {0,0,1}, {0,1,0}, {1,0,0}, {0,1,1}, {1,0,1}, {1,1,0}. If it can correctly sort all of these, you're probably done. – Mooing Duck Feb 27 '14 at 20:53
  • First thing I see is that for certain cases, you iterate past the beginning and/or past the end looking for nodes to swap. – Mooing Duck Feb 27 '14 at 20:54
  • The second thing I see is that your partition seems to get stuck in an infinite loop when there are duplicates of the pivot. When it's looping, attach a debugger and pause it, then step through and watch why it stays in the loop. – Mooing Duck Feb 27 '14 at 20:56
  • @Xarn: this is a select, not a sort. It can (and normally does) have O(n) complexity. – Jerry Coffin Feb 27 '14 at 21:13
  • @JerryCoffin I read that as a sort, noted. – Xarn Feb 27 '14 at 21:34
  • I notice a complete lack of asserts. Your code should be full of asserts and/or debugging dumps. I mean, you decrement some index -- what happens if you decrement past `0`? Or increment past the end? If you had asserts that the indexes were still within the bounds of the array, at least you'd get diagnostics right after you screwed up, instead of undefined behavior. – Yakk - Adam Nevraumont Feb 27 '14 at 22:26
  • You'll have a problem when `i` and `j` overlap. You need to prevent that. For example, `while (a[i] < pivot && i < j)` and `while (a[j] >= pivot && i < j)`. My C# implementation is similar. See code at the end of http://blog.mischel.com/2011/10/25/when-theory-meets-practice/ – Jim Mischel Feb 27 '14 at 22:41

1 Answer


There are several problems in your code.

  • First, quick_select() compares pivotIndex with n directly. That only works if n is an absolute zero-based index into the whole vector. If n is instead a rank within the current subrange a[left..right] (one-based, as in the version below), compare n with the length of the left part including the pivot, which is pivotIndex - left + 1, because left isn't always 0.
  • When n > length, the N-th element of the current range lies in the right part, where it is the (n - (pivotIndex - left + 1))-th element. Calling quick_select(a, pivotIndex + 1, right, n) with an unchanged n is therefore wrong for a relative rank; the call should be quick_select(a, pivotIndex + 1, right, n - (pivotIndex - left + 1)) (n is ONE-based).
  • It seems you're using Hoare's partitioning algorithm, but it is implemented incorrectly. Even a correct HOARE-PARTITION returns a value j such that A[p...j] ≤ A[j+1...r], whereas quick_select() needs A[p...j-1] ≤ A[j] ≤ A[j+1...r], i.e. the pivot in its final position. So I use a rand_partition() based on Lomuto's partitioning algorithm, which I wrote for another post.
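The endless loop with duplicates comes from the do/while in the question's partition: when a[i] and a[j] both equal the pivot, neither inner while advances, the swap exchanges two equal values, and i < j stays true forever. A minimal sketch of that stall follows. The helper name count_partition_rounds and the round cap are assumptions made for illustration; the loop body is copied from the question, and the vector must be non-empty and contain the pivot value.

```cpp
#include <utility>
#include <vector>

// Hypothetical helper: runs the question's do/while loop over the whole
// vector with a fixed pivot value. Returns the number of rounds executed,
// or -1 if the loop is still running after max_rounds rounds.
// Assumes a is non-empty and contains the pivot value.
int count_partition_rounds(std::vector<int> a, int pivot, int max_rounds) {
    int i = 0;
    int j = static_cast<int>(a.size()) - 1;
    int rounds = 0;
    do {
        if (++rounds > max_rounds) return -1;  // give up: no progress
        while (a[i] < pivot) i++;  // stops on any element >= pivot
        while (a[j] > pivot) j--;  // stops on any element <= pivot
        if (i < j) std::swap(a[i], a[j]);  // no-op when both equal the pivot
    } while (i < j);
    return rounds;
}
```

With distinct values, such as {3, 1, 2} and pivot 2, the loop finishes after a few rounds. With {0, 1, 0} and pivot 0, both indices stop immediately on a 0 and never move again, so the helper hits the cap and returns -1.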

Here is the fixed quick_select(), which returns the N-th smallest element of the vector (note that n is ONE-based):

int quick_select(vector<int> &a, int left, int right, int n)
{
    if ( left == right ) 
        return a[left];
    int pivotIndex = rand_partition(a, left, right);

    int length = pivotIndex - left + 1;
    if ( length == n) 
        return a[pivotIndex];
    else if ( n < length ) 
        return quick_select(a, left, pivotIndex - 1, n);
    else 
        return quick_select(a, pivotIndex + 1, right, n - length);
}

and this is the rand_partition():

int rand_partition(vector<int> &arr, int start, int end)
{
    int pivot_index = start + rand() % (end - start + 1);
    int pivot = arr[pivot_index];

    swap(arr[pivot_index], arr[end]); // swap random pivot to end.
    pivot_index = end;
    int i = start -1;

    for(int j = start; j <= end - 1; j++)
    {
        if(arr[j] <= pivot)
        {
            i++;
            swap(arr[i], arr[j]);
        }
    }
    swap(arr[i + 1], arr[pivot_index]); // swap back the pivot

    return i + 1;
}

Call srand() once before the first call to rand() to seed the random number generator; otherwise rand() returns the same sequence on every run.
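As a rough sketch of the seeding step: the helper names seed_rng and random_index below are made up for this example and are not part of the answer's code. The index formula is the one used in rand_partition(), which can pick any index in [start, end], including end.

```cpp
#include <cstdlib>
#include <ctime>

// Hypothetical helper: seed rand() from the clock. Call it once,
// for example at the start of main(), not before every rand() call.
void seed_rng() {
    std::srand(static_cast<unsigned>(std::time(nullptr)));
}

// Same formula as in rand_partition(): an index in [start, end] inclusive.
// Assumes start <= end.
int random_index(int start, int end) {
    return start + std::rand() % (end - start + 1);
}
```

Seeding from the clock is enough for picking pivots; it is not meant for anything that needs high-quality randomness.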
Driver program to test the above functions:

int main()
{
    int A1[] = {1, 0, 3, 5, 0, 8, 6, 0, 9, 0};
    vector<int> a(A1, A1 + 10);
    cout << "6th order element " << quick_select(a, 0, 9, 6) << endl;
    vector<int> b(A1, A1 + 10); // note that the vector is modified by quick_select()
    cout << "7th order element " << quick_select(b, 0, 9, 7) << endl;
    vector<int> c(A1, A1 + 10);
    cout << "8th order element " << quick_select(c, 0, 9, 8) << endl;
    vector<int> d(A1, A1 + 10);
    cout << "9th order element " << quick_select(d, 0, 9, 9) << endl;
    vector<int> e(A1, A1 + 10);
    cout << "10th order element " << quick_select(e, 0, 9, 10) << endl;
}
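Beyond printing a few ranks, one way to gain confidence is to compare every rank against a sorted copy, in the spirit of the small test inputs suggested in the comments. The following is a self-contained sketch: rand_partition() and quick_select() repeat the functions above (written without using namespace std), while check_all_ranks is a hypothetical helper added only for this check.

```cpp
#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>

// Lomuto-style partition with a random pivot, as in the answer above.
int rand_partition(std::vector<int> &arr, int start, int end) {
    int pivot_index = start + std::rand() % (end - start + 1);
    int pivot = arr[pivot_index];
    std::swap(arr[pivot_index], arr[end]);  // park the pivot at the end
    int i = start - 1;
    for (int j = start; j < end; j++) {
        if (arr[j] <= pivot) {
            i++;
            std::swap(arr[i], arr[j]);
        }
    }
    std::swap(arr[i + 1], arr[end]);  // move the pivot to its final position
    return i + 1;
}

// Returns the n-th smallest element of a[left..right]; n is ONE-based.
int quick_select(std::vector<int> &a, int left, int right, int n) {
    if (left == right) return a[left];
    int pivotIndex = rand_partition(a, left, right);
    int length = pivotIndex - left + 1;  // size of left part incl. pivot
    if (n == length) return a[pivotIndex];
    if (n < length) return quick_select(a, left, pivotIndex - 1, n);
    return quick_select(a, pivotIndex + 1, right, n - length);
}

// Hypothetical helper: true if every rank 1..size agrees with a sorted copy.
// An empty input has no ranks to check and yields true.
bool check_all_ranks(const std::vector<int> &input) {
    std::vector<int> sorted = input;
    std::sort(sorted.begin(), sorted.end());
    const int size = static_cast<int>(input.size());
    for (int n = 1; n <= size; n++) {
        std::vector<int> work = input;  // quick_select reorders its argument
        if (quick_select(work, 0, size - 1, n) != sorted[n - 1]) return false;
    }
    return true;
}
```

Calling check_all_ranks on inputs such as {0, 0}, {1, 0, 1} or the ten-element array from the question exercises the duplicate cases that made the original partition loop forever.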
jfly