Problem
I want to sort each row of a matrix independently, but instead of the sorted values I want the rank of each element within its row (ties are ranked left to right).
Example
Values Rank
------------- --------------
[5, 4, 1, 9] [2, 1, 0, 3]
[1, 4, 3, 2] --> [0, 3, 2, 1]
[2, 4, 2, 0] [1, 3, 2, 0]
Attempt
I've found these two related examples:
Ranking rows of a matrix
Sort rows of a matrix
The first shows how to use an index vector and a permutation iterator to return the indices of the sorted values. The second shows how to sort each row of a matrix with the "back-to-back" method (two successive stable sort-by-key passes). However, I can't figure out how to combine these two ideas.
I tried using a zip_iterator to combine the values and indices into tuples and then applying the back-to-back method, but I couldn't get sort-by-key to work on the zipped tuples.
I also tried doing the back-to-back sort first and building the index afterwards. At that point the values are already sorted, so the index is always [0, 1, 2, 3] for every row of the matrix.
Code
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iostream>
#include <stdlib.h>
#include <string>

#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/equal.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/generate.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
using namespace std;
#define NSORTS 5
#define DSIZE 4
// -------------------
// Print
// -------------------
/// Print a labelled container on one line.
///
/// The label is right-aligned in 13 columns and followed by " :: ". Each
/// element is then right-aligned in 2 columns and followed by a space. An
/// extra space is emitted after every `per_row` elements so that rows of a
/// flattened matrix stand out.
///
/// @param name     label printed before the elements
/// @param toPrint  any range-for iterable container (std::vector,
///                 thrust::host_vector, thrust::device_vector, ...). It is
///                 taken by const reference so that a device_vector is not
///                 duplicated (device allocation + copy) just to print it.
/// @param per_row  size of each visual group. The default of 4 is the
///                 original hard-coded width. A value <= 0 disables grouping.
template <class Vector>
void print(const std::string& name, const Vector& toPrint, int per_row = 4)
{
    std::cout << std::setw(13) << name << " :: ";
    int i = 0;
    // const auto& also binds to proxy references (e.g. thrust device_reference).
    for (const auto& x : toPrint)
    {
        ++i;
        std::cout << std::setw(2) << x << " ";
        if (per_row > 0 && i % per_row == 0)
            std::cout << " ";
    }
    std::cout << std::endl;
}
// ---------------------
// Print Title
// ---------------------
/// Print a section heading for the debug output.
///
/// The output is two blank lines, then `title` (indented by one space)
/// framed between two dashed rules.
///
/// @param title  heading text; taken by const reference to avoid copying it
void print_title(const std::string& title)
{
    std::cout << "\n\n"
              << "-------------------\n"
              << " " << title << "\n"
              << "-------------------\n";
}
// ---------------------
// My Mod
// ---------------------
// Host-side generator for thrust::generate.
// Despite the name, it divides rather than takes a modulus. Each call returns
// the row (segment) number of the next flattened element: 0 for the first
// DSIZE calls, 1 for the next DSIZE calls, and so on.
// The position is kept in the global counter my_mod_start. It is never reset,
// so a second pass continues numbering where the first one stopped.
int my_mod_start = 0;
int my_mod(){
    const int flat_index = my_mod_start;
    my_mod_start = flat_index + 1;
    return flat_index / DSIZE;
}
// ------------------
// Clamp
// ------------------
// Unary functor that maps every value <= 0 to 0 and every value > 0 to 1,
// returned as the input type T.
// operator() is const so the functor can also be invoked through a const
// object (e.g. when stored inside a transform_iterator).
// NOTE(review): with `using namespace std;` at file scope and C++17's
// std::clamp visible (e.g. via a transitively included <algorithm>), an
// unqualified `clamp()` can be ambiguous. Write `::clamp()` at call sites if
// that happens.
struct clamp
{
    template <typename T>
    __host__ __device__
    T operator()(T data) const
    {
        return data <= 0 ? T(0) : T(1);
    }
};
int main()
{
// Initialize
thrust::host_vector<int> h_data(DSIZE * NSORTS);
thrust::generate(h_data.begin(), h_data.end(), rand);
thrust::transform(h_data.begin(), h_data.end(), h_data.begin(), thrust::placeholders::_1 % 10);
int size = DSIZE * NSORTS;
// Device Vectors
thrust::device_vector<int> d_data = h_data;
thrust::device_vector<int> d_idx(size);
thrust::device_vector<int> d_result(size);
thrust::sequence(d_idx.begin(), d_idx.end());
// Segments
thrust::host_vector<int> h_segments(size);
thrust::generate(h_segments.begin(), h_segments.end(), my_mod);
thrust::device_vector<int> d_segments = h_segments;
print_title("Generate");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
// Sort 1
thrust::stable_sort_by_key(d_data.begin(), d_data.end(), d_segments.begin());
print_title("Sort 1");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
// Sort 2
thrust::stable_sort_by_key(d_segments.begin(), d_segments.end(), d_data.begin());
print_title("Sort 2");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
// Adjacent Difference
thrust::device_vector<int> d_diff(size);
thrust::adjacent_difference(d_data.begin(), d_data.end(), d_diff.begin());
d_diff[0] = 0;
print_title("Adj Diff");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
print("Difference", d_diff);
// Transform
thrust::transform(d_diff.begin(), d_diff.end(), d_diff.begin(), clamp());
print_title("Transform");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
print("Difference", d_diff);
// Inclusive Scan
thrust::inclusive_scan_by_key(d_segments.begin(), d_segments.end(), d_diff.begin(), d_diff.begin());
print_title("Inclusive");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
print("Difference", d_diff);
// Results
thrust::copy(d_diff.begin(), d_diff.end(), thrust::make_permutation_iterator(d_result.begin(), d_idx.begin()));
print_title("Results");
print("Data", d_data);
print("Index", d_idx);
print("Segments", d_segments);
print("Difference", d_diff);
print("Results", d_result);
}
Edit: corrected the example rank matrix.