
I have a data table with probabilities for a discrete distribution stored in columns.

For example:

library(data.table)
dt <- data.table(p1 = c(0.5, 0.25, 0.1), p2 = c(0.25, 0.5, 0.1), p3 = c(0.25, 0.25, 0.8))

I'd like to create a new column of a random variable sampled using the probabilities in the same row. In data.table syntax I imagine it working like this:

dt[, sample := sample(1:3, 1, prob = c(p1, p2, p3))]

If there were a 'psample' function similar to 'pmin' and 'pmax', this would work. I was able to make this work using apply, but the downside is that with my real data set it takes longer than I would like. Is there a way to make this work using data.table? The apply solution is given below.

dt[, sample := apply(dt, 1, function(x) sample(1:3, 1, prob = x[c('p1', 'p2', 'p3')]))]
jay.sf
  • Related: [Efficiently apply sample() in R](https://stackoverflow.com/questions/53187985/efficiently-apply-sample-in-r). – Henrik Jul 16 '22 at 22:06
  • @Henrik Nice. Throws an error, but it's solved [there](https://stackoverflow.com/a/59357190/6574038). – jay.sf Jul 17 '22 at 08:46
  • Not an answer, but for the record: if you fail to vectorize your function, common alternatives to `apply` are (1) `by = 1:nrow(dt)`, or (2) melt to long format. Described e.g. here: [Efficient row-wise operations on a data.table](https://stackoverflow.com/questions/7885147/efficient-row-wise-operations-on-a-data-table); [How to do row wise operations on .SD columns in data.table](https://stackoverflow.com/questions/33353036/how-to-do-row-wise-operations-on-sd-columns-in-data-table), posts that you should have found, even with a very poor google-fu - "R data.table rowwise" ;) – Henrik Jul 17 '22 at 11:54
  • @Henrik Exactly. With my [`psampv`](https://stackoverflow.com/a/73001876/6574038) I actually was inspired from the `pmin` solution of your first link. – jay.sf Jul 17 '22 at 12:09
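For reference, here is a minimal sketch of the two row-wise alternatives Henrik mentions above, using the example table from the question (the rid and sample2 names are invented for illustration):

library(data.table)
dt <- data.table(p1 = c(0.5, 0.25, 0.1), p2 = c(0.25, 0.5, 0.1), p3 = c(0.25, 0.25, 0.8))

# (1) group by row number, drawing once per one-row group
dt[, sample := sample.int(3, 1, prob = c(p1, p2, p3)), by = 1:nrow(dt)]

# (2) melt to long format, draw once per row id, then join the draws back
dt[, rid := .I]
long  <- melt(dt[, .(rid, p1, p2, p3)], id.vars = "rid")
draws <- long[, .(sample2 = sample.int(3, 1, prob = value)), by = rid]
dt[draws, sample2 := i.sample2, on = "rid"]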

3 Answers


If you are choosing from 1:n, you can use sample.int, which is faster. Applying over a matrix is also faster, and putting both into a function psamp is faster still.

So, try this (the dt[, 1:3] subset keeps it from failing once the sample column has been added):

psamp <- function(x) sample.int(n = 3, size = 1, prob = x)
dt[, sample := apply(as.matrix(dt[, 1:3]), 1, psamp)]

To get somewhat rid of the apply (at least on the surface; Vectorize uses mapply internally), we can Vectorize psamp and use do.call. Additionally, as @IRTFM suggests in his answer, we should make use of the .SD symbol.

psampv <- Vectorize(function(p1, p2, p3) sample.int(n = 3, size = 1, replace = TRUE, prob = c(p1, p2, p3)))
dt[, sample := do.call(psampv, .SD), .SDcols = c('p1', 'p2', 'p3')]

To improve performance by more than an order of magnitude, we may use Rcpp, as suggested by @Henrik in the comments. I have slightly adapted the code from this answer and use the newer Rcpp::sample, which kindly gives results identical to base::sample under the same set.seed.

#include <Rcpp.h>

// [[Rcpp::export]]
Rcpp::IntegerVector sample_matrix1(Rcpp::NumericMatrix x, Rcpp::IntegerVector choice_set) {
  int n = x.nrow();
  Rcpp::IntegerVector result(n);
  for (int i = 0; i < n; ++i) {
    // use row i of the probability matrix as the sampling weights
    Rcpp::NumericVector z(x(i, Rcpp::_));
    // draw a single value from choice_set with those weights
    result[i] = Rcpp::sample(choice_set, 1, false, z)[0];
  }
  return result;
}

Rcpp::sourceCpp("sample_matrix1.cpp")

dt[, sample := sample_matrix1(as.matrix(.SD), 1:3), .SDcols=c('p1','p2','p3')] 

Benchmark on 100k rows, 100 repetitions each:

Unit: milliseconds
          expr        min         lq       mean     median         uq       max neval cld
      psamp_:= 1195.16708 1259.06558 1327.19581 1311.17878 1349.98905 1515.1187   100   b
     psamp_.SD 1225.90467 1257.37766 1318.74885 1289.27571 1335.07736 1522.3423   100   b
     psamp_set 1181.44985 1256.73204 1320.29317 1301.75657 1335.22009 1491.3870   100   b
 psamp_do.call 1181.93117 1251.45863 1316.23306 1285.85710 1337.06674 1476.8023   100   b
          rcpp   60.73652   67.15291   72.76073   70.47052   73.91629  127.8278   100  a 
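The benchmark code itself isn't shown above; here is a minimal sketch of how it could be set up, assuming a 100k-row table and the psamp, psampv, and sample_matrix1 functions defined earlier (the s1..s5 result columns are invented names):

library(data.table)
library(microbenchmark)

set.seed(1)
n  <- 1e5
p  <- matrix(runif(3 * n), ncol = 3)
p  <- p / rowSums(p)   # normalize each row to sum to 1
dt <- data.table(p1 = p[, 1], p2 = p[, 2], p3 = p[, 3])

microbenchmark(
  `psamp_:=`    = dt[, s1 := apply(as.matrix(dt[, 1:3]), 1, psamp)],
  psamp_.SD     = dt[, s2 := apply(.SD, 1, psamp), .SDcols = c('p1', 'p2', 'p3')],
  psamp_set     = set(dt, j = "s3", value = apply(as.matrix(dt[, 1:3]), 1, psamp)),
  psamp_do.call = dt[, s4 := do.call(psampv, .SD), .SDcols = c('p1', 'p2', 'p3')],
  rcpp          = dt[, s5 := sample_matrix1(as.matrix(.SD), 1:3), .SDcols = c('p1', 'p2', 'p3')],
  times = 100
)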
jay.sf

I think the proper data.table approach would be to use the .SD facilities:

dt2 <- rbind(dt, dt, dt, dt)
psamp <- function(x) sample.int(n = 3, size = 1, prob = x) # from jay.sf

dt2[, sample := apply(.SD, 1, psamp), .SDcols = c('p1', 'p2', 'p3')]
> dt2
      p1   p2   p3 sample
 1: 0.50 0.25 0.25      2
 2: 0.25 0.50 0.25      1
 3: 0.10 0.10 0.80      2
 4: 0.50 0.25 0.25      3
 5: 0.25 0.50 0.25      2
 6: 0.10 0.10 0.80      3
 7: 0.50 0.25 0.25      3
 8: 0.25 0.50 0.25      2
 9: 0.10 0.10 0.80      3
10: 0.50 0.25 0.25      1
11: 0.25 0.50 0.25      2
12: 0.10 0.10 0.80      3

A note on style: it's better to refrain from giving R objects names that already belong to R functions, such as df (the F-distribution density), dt (the t-distribution density), or data (a function for loading canned datasets).
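A hypothetical illustration of the clash (invented values):

library(data.table)
dt <- data.table(x = 1:3)   # shadows stats::dt, the t-distribution density
dt(1, df = 5)               # still works: in a call, R searches for a *function* named dt
get("dt")                   # but this now returns the data.table, not the density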

IRTFM
  • I do like this approach because it allows you to use different probability columns without defining a new function. – Gregory Apartment Jul 17 '22 at 03:18
  • Of course @IRTFM, thanks for sharing your wisdom! I kept working on it, trying to get [rid](https://stackoverflow.com/a/73001876/6574038) of the apply. In the end, `Rcpp` beats us all, though. – jay.sf Jul 17 '22 at 10:21
  • Oh yeah, I would never want to challenge Rcpp. – IRTFM Jul 17 '22 at 18:08

I think this is also an option, and it might be quite fast. Perhaps @jay.sf can let us know, as it also uses psamp (thanks @jay.sf):

set(dt, j = "sample", value = apply(dt, 1, psamp))
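One caveat, mirroring the dt[, 1:3] guard in the first answer (the column subset here is my assumption): applying over the whole table will fail on a second run once the sample column exists, so it is safer to restrict apply to the probability columns:

# restrict apply() to the probability columns so a re-run doesn't
# feed the new `sample` column back into psamp
set(dt, j = "sample", value = apply(dt[, c("p1", "p2", "p3")], 1, psamp))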
langtang
  • Isn't `set` equivalent to `:=`? I took it up in the benchmark; it appears to be slightly faster. – jay.sf Jul 17 '22 at 11:50