
I am stuck trying to find the indices of the elements of a vector x that also appear in another vector vals, using RcppArmadillo. Both x and vals are of type arma::uvec.

In R, this would be straightforward:

x <- c(1,1,1,4,2,4,4)
vals <- c(1,4)
which(x %in% vals)
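For reference, here is what which(x %in% vals) computes, written as a plain standard-C++ loop. This is only a sketch of the expected result: std::vector<unsigned> stands in for arma::uvec, and which_in is a hypothetical name. It returns the 1-based positions, in ascending order, of every element of x that equals some element of vals:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics of R's which(x %in% vals): the 1-based positions i,
// in ascending order, such that x[i] equals at least one element of vals.
std::vector<std::size_t> which_in(const std::vector<unsigned>& x,
                                  const std::vector<unsigned>& vals) {
  std::vector<std::size_t> out;
  for (std::size_t i = 0; i < x.size(); ++i) {
    for (unsigned v : vals) {
      if (x[i] == v) {
        out.push_back(i + 1);  // R indices are 1-based
        break;                 // record each position once, even if vals has duplicates
      }
    }
  }
  return out;
}
```

With the x and vals above this gives 1 2 3 4 6 7, the same positions R reports.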

I've scanned the Armadillo docs, and find() was my obvious first try, but it doesn't work here because vals is a vector rather than a single value. I've also tried intersect(), but it only returns one index per unique common value, not every matching position.

What would be a good, efficient way to do this with Armadillo? Do I have to loop over the elements of vals and call find() for each one?

Adrian Mole
Econ21

2 Answers


A quick and dirty way:

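Rcpp::cppFunction("
  arma::uvec ind(arma::uvec x, arma::uvec y) {
    // a(k) counts how many values of y are equal to x(k)
    arma::uvec a(x.n_elem, arma::fill::zeros);
    for (auto i : y) a = a + (x == i);  // add the 0/1 indicator for each value
    // positions with a non-zero count, shifted to R's 1-based indexing
    return arma::find(a) + 1;
  }
", depends = "RcppArmadillo")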

c(ind(x, vals))
[1] 1 2 3 4 6 7
Onyambu

For completeness, I came up with this solution in the meantime:

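#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// Returns the 0-based positions in x of elements equal to some value in y
// (add 1 to get R-style indices). Positions are grouped by the order of the
// values in y rather than sorted, and repeated if y contains duplicates.
// [[Rcpp::export]]
arma::uvec getIndex(arma::uvec x, arma::uvec y) {

  arma::uword k = 0, n = y.n_elem;
  arma::uvec tmp(n);

  // First pass: count the matches for each value in y.
  for (arma::uword i = 0; i < n; i++) {
    arma::uvec tmpID = arma::find(x == y(i));
    tmp(i) = tmpID.n_elem;
  }

  arma::uvec out(arma::accu(tmp));

  // Second pass: copy the matching positions into out.
  for (arma::uword i = 0; i < n; i++) {
    arma::uvec id = arma::find(x == y(i));
    for (arma::uword j = 0; j < id.n_elem; j++) {
      out(j + k) = id(j);
    }
    k += tmp(i);
  }
  return out;
}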
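A different technique from both answers, sketched in plain standard C++ (assumptions: std::vector<unsigned> stands in for arma::uvec, and match_positions is a hypothetical name): load vals into a hash set once, then scan x a single time. This avoids a full pass over x for every value, and it returns positions in ascending order, like which(), instead of grouping them by value:

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Single pass over x with a hash-set lookup for each element.
// Returns 0-based positions in ascending order (add 1 for R), independent of
// the order of, or duplicates in, vals. Expected cost is
// O(x.size() + vals.size()) rather than one full scan of x per value.
std::vector<std::size_t> match_positions(const std::vector<unsigned>& x,
                                         const std::vector<unsigned>& vals) {
  const std::unordered_set<unsigned> lookup(vals.begin(), vals.end());
  std::vector<std::size_t> out;
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (lookup.count(x[i]) != 0) {
      out.push_back(i);
    }
  }
  return out;
}
```

For x = {4, 2, 1} and vals = {1, 4} this returns 0, 2, whereas getIndex returns 2, 0. Inside an RcppArmadillo function the std::vector result should be convertible with arma::conv_to<arma::uvec>::from(...), adding 1 if the indices are returned to R.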
Econ21