
As a follow-up to my previous question, I'm interested in improving the performance of the existing recursive sampling function.

By recursive sampling I mean randomly choosing up to n unique unexposed IDs for one exposed ID, then randomly choosing up to n unique unexposed IDs from the remaining unexposed IDs for the next exposed ID, and so on. If no unexposed IDs remain for a given exposed ID, that exposed ID is left out.
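To make the definition concrete, here is a minimal base-R sketch on a tiny toy data set. The function name stepwise_sample and the toy data are illustrative assumptions, not code from the question. With n = 2, exposed ID 1 takes both of its candidates, exposed ID 2 is left with only "c", and exposed ID 3 has nothing left and is dropped.

```r
# Minimal sketch of the step-wise sampling described above.
# `stepwise_sample` is a hypothetical name, not one of the functions in this thread.
stepwise_sample <- function(data, n) {
  used <- data$unexposed[0]  # unexposed IDs already assigned (empty, same type)
  out <- list()
  for (g in unique(data$exposed)) {
    cand <- unique(data$unexposed[data$exposed == g])  # unique IDs for this exposed ID
    cand <- cand[!cand %in% used]                      # drop IDs taken by earlier groups
    if (length(cand) == 0) next                        # nothing left: leave g out
    # sample.int() on positions avoids sample()'s special case for length-1 input
    pick <- cand[sample.int(length(cand), min(length(cand), n))]
    used <- c(used, pick)
    out[[length(out) + 1]] <- data.frame(exposed = g, unexposed = pick)
  }
  do.call(rbind, out)
}

toy <- data.frame(
  exposed   = c(1, 1, 2, 2, 2, 3),
  unexposed = c("a", "b", "a", "b", "c", "a")
)
set.seed(1)
res <- stepwise_sample(toy, n = 2)
res
```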

The original function is as follows:

library(dplyr)

recursive_sample <- function(data, n) {
 
 groups <- unique(data[["exposed"]])
 out <- data.frame(exposed = character(), unexposed = character())
 
 for (group in groups) {
  
  chosen <- data %>%
   filter(exposed == group,
          !unexposed %in% out$unexposed) %>%
   group_by(unexposed) %>%
   slice(1) %>%
   ungroup() %>%
   sample_n(size = min(n, nrow(.))) 
  
  out <- rbind(out, chosen)
  
 }
 
 out
 
}

I was able to create a more efficient one as follows:

recursive_sample2 <- function(data, n) {
 
 groups <- unique(data[["exposed"]])
 out <- tibble(exposed = integer(), unexposed = integer())
 
 for (group in groups) {
  
  chosen <- data %>%
   filter(exposed == group,
          !unexposed %in% out$unexposed) %>%
   filter(!duplicated(unexposed)) %>%
   sample_n(size = min(n, nrow(.))) 
  
  out <- bind_rows(out, chosen)
  
 }
 
 out
 
}

Sample data and benchmarking:

library(microbenchmark)

set.seed(123)
df <- tibble(exposed = rep(1:100, each = 100),
             unexposed = sample(1:7000, 10000, replace = TRUE))

microbenchmark(f1 = recursive_sample(df, 5),
               f2 = recursive_sample2(df, 5),
               times = 10)

Unit: milliseconds
 expr       min        lq      mean    median        uq      max neval cld
   f1 1307.7198 1316.5276 1379.0533 1371.3952 1416.6360 1540.955    10   b
   f2  839.0086  865.2547  914.8327  901.2288  970.9518 1036.170    10  a 

However, for my actual dataset I need an even more efficient (i.e., faster) function. Any ideas for a more efficient version, whether using data.table, parallelisation or other approaches, are welcome.

ThomasIsCoding
tmfmnk
  • A reproducible example would be helpful. – s_baldur Jun 05 '23 at 11:02
  • I don't see any recursion here. Can you explain what the function is supposed to do? – George Savva Jun 05 '23 at 11:11
  • @George Savva "recursive" might not be the perfect way to describe this process, I agree. In essence, the function should sample up to n unique unexposed IDs for each exposed ID, but in a step-wise fashion. This means sampling up to n unique unexposed IDs for one exposed ID and then sampling up to n unexposed IDs from the remaining unexposed IDs for another exposed ID. If there are no remaining unexposed IDs for a given exposed ID, then the exposed ID is left out. – tmfmnk Jun 05 '23 at 14:24
  • I guess "Dynamic Programming" (DP) describes your question more accurately, since you need to dynamically adapt your sampling according to your previous actions. – ThomasIsCoding Jun 05 '23 at 21:08
  • I suggest you provide a self-contained data example and the expected output as well, and rephrase "recursive" as something else to avoid misunderstanding. – ThomasIsCoding Jun 12 '23 at 13:48

3 Answers


Update, with More Improvement

A more concise solution uses Reduce + split: it shuffles the rows of the data first and then samples group by group, iteratively:

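ftic <- function(data, n) {
    Reduce(
        \(x, y) {
            y <- subset(y, !unexposed %in% x$unexposed) # drop IDs used by earlier groups
            # keep each remaining ID once; note that shuffling before de-duplicating
            # favours IDs that occur more often within a group
            y <- y[!duplicated(y$unexposed), ]
            rbind(x, head(y, n)) # rows are already shuffled, so take the first n
        },
        split(data[sample(1:nrow(data)), ], ~exposed), # shuffle rows, then split by exposed
        init = data[0, ] # empty start, so the first group is also capped at n
    )
}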
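Why the starting value of Reduce() matters here: without init, Reduce() uses the first list element itself as the accumulator, so that element never passes through the combining function. The toy vectors and the take_first helper below are illustrative only.

```r
# Combining function: append the first element of each chunk to the accumulator.
take_first <- function(acc, x) c(acc, head(x, 1))
chunks <- list(1:3, 4:6, 7:9)

# Without `init`, the whole first chunk (1:3) becomes the starting accumulator.
no_init <- Reduce(take_first, chunks)

# With an empty `init`, every chunk, including the first, goes through take_first().
with_init <- Reduce(take_first, chunks, init = integer(0))

no_init
with_init
```

The same reasoning applies to the split data frames: starting from an empty data frame is what keeps the first exposed group from being returned in full.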

Below is a tougher stress test on data with 1e6 rows, comparing the following approaches:

ftmfmnk <- function(data, n) {
    groups <- unique(data[["exposed"]])
    out <- tibble(exposed = integer(), unexposed = integer())

    for (group in groups) {
        chosen <- data %>%
            filter(
                exposed == group,
                !unexposed %in% out$unexposed
            ) %>%
            filter(!duplicated(unexposed)) %>%
            sample_n(size = min(n, nrow(.)))

        out <- bind_rows(out, chosen)
    }

    out
}

fminem <- function(data, n) {
    groups <- unique(data[["exposed"]])
    # working on vectors is faster
    id <- 1:nrow(data)
    i <- vector("integer")
    unexposed2 <- vector(class(data$unexposed))
    ex <- data$exposed
    ux <- data$unexposed

    for (group in groups) {
        f1 <- ex == group # first filter
        f2 <- !ux[f1] %in% unexposed2 # 2nd filter (only on those that match 1st)
        id3 <- id[f1][f2][!duplicated(ux[f1][f2])] # check duplicates only on needed
        # and select necessary row ids
        is <- sample(id3, size = min(length(id3), n)) # sample row ids (caution: if id3 has length 1, sample() draws from 1:id3)
        i <- c(i, is) # add to list
        unexposed2 <- ux[i] # resave unexposed2
    }
    out <- data[i, ] # only one data.frame subset
    out$id <- NULL
    out
}

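ftic <- function(data, n) {
    Reduce(
        \(x, y) {
            y <- subset(y, !unexposed %in% x$unexposed) # drop IDs used by earlier groups
            # keep each remaining ID once; note that shuffling before de-duplicating
            # favours IDs that occur more often within a group
            y <- y[!duplicated(y$unexposed), ]
            rbind(x, head(y, n)) # rows are already shuffled, so take the first n
        },
        split(data[sample(1:nrow(data)), ], ~exposed), # shuffle rows, then split by exposed
        init = data[0, ] # empty start, so the first group is also capped at n
    )
}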

The benchmark is as follows:

set.seed(123)
df <- tibble(
    exposed = rep(1:1000, each = 1000),
    unexposed = sample(1:70000, 1000000, replace = TRUE)
)

mbm <- microbenchmark(
    tmfmnk = ftmfmnk(df, 5),
    minem = fminem(df, 5),
    tic = ftic(df, 5),
    times = 10
)

boxplot(mbm)

which gives:

> mbm
Unit: milliseconds
   expr        min         lq       mean     median         uq        max neval
 tmfmnk 36809.9563 44276.3545 43780.8407 44897.2661 46175.1031 46948.8906    10
  minem  5361.2796  5932.7752  5923.8811  6010.7775  6047.3716  6233.2919    10
    tic   504.5749   519.5997   641.7935   607.2825   729.4545   868.1283    10



Previous Naïve Approach

Nothing advanced here, just a dynamic programming scheme with a for loop. I believe there must be more performant approaches than mine.

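dp <- function(df, n) {
    d <- table(df) # contingency table: rows = exposed, columns = unexposed
    out <- list()
    rnm <- row.names(d)
    cnm <- colnames(d)
    for (i in 1:nrow(d)) {
        v <- which(d[i, ] > 0) # unexposed IDs still available for this exposed ID
        l <- length(v)
        if (l == 0) next # nothing left: leave this exposed ID out
        idx <- v[sample(l, min(l, n))]
        out[[i]] <- data.frame(exposed = rnm[i], unexposed = cnm[idx])
        d[, idx] <- 0 # mark the chosen unexposed IDs as used for all later rows
    }
    do.call(rbind, out)
}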
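The table-based bookkeeping in dp() can be seen on a toy example (the data below is illustrative only): table() turns the two ID columns into an exposed-by-unexposed count matrix, and zeroing a column removes that unexposed ID from the candidates of every later row.

```r
# Toy data: exposed ID 1 can use 10 or 20, exposed ID 2 can use 10 or 30.
toy <- data.frame(exposed = c(1, 1, 2, 2), unexposed = c(10, 20, 10, 30))
d <- table(toy)  # rows = exposed IDs, columns = unexposed IDs, cells = counts
d

# Suppose exposed ID 1 was assigned unexposed ID 10: zero that column for all rows.
d[, "10"] <- 0

# Remaining candidates for exposed ID 2 are the columns with a positive count.
remaining <- names(which(d["2", ] > 0))
remaining
```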

and the benchmarking

set.seed(123)
df <- tibble(
    exposed = rep(1:100, each = 100),
    unexposed = sample(1:7000, 10000, replace = TRUE)
)

mbm <- microbenchmark(
    f1 = recursive_sample(df, 5),
    f2 = recursive_sample2(df, 5),
    f3 = dp(df, 5),
    times = 10
)

boxplot(mbm)

shows

> mbm
Unit: milliseconds
 expr       min        lq      mean    median        uq       max neval
   f1 1271.0135 1302.4310 1449.2193 1326.7630 1686.4329 1888.4549    10
   f2  507.9350  516.8854  617.0313  559.0422  706.4300  801.0124    10
   f3  212.8944  247.0066  278.1792  271.9010  309.7377  354.4320    10


Also, to check the result res <- dp(df, 5), we can use

> table(res$exposed)

  1  10 100  11  12  13  14  15  16  17  18  19   2  20  21  22  23  24  25  26
  5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5
 27  28  29   3  30  31  32  33  34  35  36  37  38  39   4  40  41  42  43  44
  5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5
 45  46  47  48  49   5  50  51  52  53  54  55  56  57  58  59   6  60  61  62
  5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5
 63  64  65  66  67  68  69   7  70  71  72  73  74  75  76  77  78  79   8  80
  5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5
 81  82  83  84  85  86  87  88  89   9  90  91  92  93  94  95  96  97  98  99
  5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5   5

> anyDuplicated(res$unexposed)
[1] 0
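The two checks above cover group sizes and reuse of unexposed IDs. A sketch of a slightly more complete check also confirms that every sampled pair actually occurs in the input; check_sample is a hypothetical helper, not from the thread, and the toy data is illustrative.

```r
# Hypothetical helper: returns one logical per constraint of the task.
check_sample <- function(res, data, n) {
  # Build "exposed|unexposed" keys so pairs can be matched with %in%
  key_res  <- paste(res$exposed, res$unexposed, sep = "|")
  key_data <- paste(data$exposed, data$unexposed, sep = "|")
  c(
    no_reused_ids  = anyDuplicated(res$unexposed) == 0,  # each unexposed ID used once
    at_most_n      = all(table(res$exposed) <= n),       # at most n per exposed ID
    pairs_in_input = all(key_res %in% key_data)          # only pairs present in data
  )
}

toy  <- data.frame(exposed = c(1, 1, 2), unexposed = c(10, 20, 10))
good <- data.frame(exposed = c(1, 2), unexposed = c(20, 10))
bad  <- data.frame(exposed = c(1, 2), unexposed = c(10, 10))
check_sample(good, toy, n = 1)
check_sample(bad, toy, n = 1)
```

The same call could be applied to res <- dp(df, 5) together with df and n = 5.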
ThomasIsCoding

Working on vectors is much faster:

recursive_sample3 <- function(data, n) {
  groups <- unique(data[["exposed"]])
  # working on vectors is faster
  id <- 1:nrow(data)
  i <- vector('integer')
  unexposed2 <- vector(class(data$unexposed))
  ex <- data$exposed
  ux <- data$unexposed
  
  for (group in groups) {
    f1 <- ex == group # first filter
    f2 <- !ux[f1] %in% unexposed2 # 2nd filter (only on those that match 1st)
    id3 <- id[f1][f2][!duplicated(ux[f1][f2])] # check duplicates only on needed
    # and select necessary row ids
    is <- sample(id3, size = min(length(id3), n)) # sample row ids (caution: if id3 has length 1, sample() draws from 1:id3)
    i <- c(i, is) # add to list
    unexposed2 <- ux[i] # resave unexposed2
  }
  out <- data[i, ] # only one data.frame subset
  out$id <- NULL
  out
}

benchmarks:

microbenchmark(f1 = recursive_sample(df, 5),
               f2 = recursive_sample2(df, 5),
               f3 = recursive_sample3(df, 5),
               times = 3)
# Unit: milliseconds
# expr       min        lq        mean    median         uq       max neval cld
#   f1 1399.8988 1407.1939 1422.008133 1414.4889 1433.06280 1451.6367     3 a  
#   f2  667.0813  673.7229  678.106400  680.3645  683.61895  686.8734     3  b 
#   f3    6.2399    6.2625    9.531267    6.2851   11.17695   16.0688     3   c

Iterating on recursive_sample3 and addressing the behaviour of sample() when its first argument has length 1:

f_minem <- function(data, n) {
  i <- vector('integer')
  unexposed2 <- vector(class(data$unexposed))
  ux <- data$unexposed
  exl <- split(1:nrow(data), data$exposed)
  for (ii in exl) {
    f2 <- !ux[ii] %in% unexposed2
    f12 <- ii[f2]
    dn <- !duplicated(ux[f12])
    id3 <- f12[dn]
    is <- id3[sample.int(length(id3), min(length(id3), n))] # draw up to n of all candidates
    i <- c(i, is)
    unexposed2 <- ux[i]
  }
  out <- data[i, ]
  out
}
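A quick check of why the two-argument form of sample.int() matters when indexing: sample.int(k) alone returns a permutation of 1:k, so indexing with sample.int(min(length(id3), n)) can only ever reorder the first n candidates. The candidate vector below is illustrative only.

```r
set.seed(1)
id3 <- c(11L, 12L, 13L, 14L, 15L, 16L)  # illustrative candidate row ids
n <- 2

# One argument: a permutation of 1:2, so only id3[1] and id3[2] can be chosen.
one_arg <- replicate(200, id3[sample.int(min(length(id3), n))])

# Two arguments: draw min(length(id3), n) positions out of all length(id3) positions.
two_arg <- replicate(200, id3[sample.int(length(id3), min(length(id3), n))])

table(one_arg)
table(two_arg)
```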

Second round of benchmarks:

microbenchmark::microbenchmark(
  recursive_sample3 = recursive_sample3(df, 5L),
  recursive_sample4 = recursive_sample4(setDT(df), 5L),
  f_minem = f_minem(df, 5L),
  setup = {df <- copy(data)}
  , times = 10
)
# Unit: milliseconds
#              expr    min     lq    mean  median      uq     max neval cld
# recursive_sample3 6.2102 6.2974 9.63296 6.43245 16.3367 17.0746    10  a 
# recursive_sample4 3.5145 3.6249 3.67077 3.67075  3.7513  3.7970    10   b
#           f_minem 2.1705 2.1920 2.27510 2.23215  2.3784  2.4585    10   b
minem
  • This solution is incredibly fast on my actual data. Thanks a lot for this :) – tmfmnk Jun 05 '23 at 14:31
  • This could give erroneous results if `id3` is of length 1 due to the behavior of `sample` when the first argument has length 1. – jblood94 Jun 05 '23 at 14:37

A data.table solution that keeps a running list of already-sampled values and excludes them with setdiff (or with %!in% from collapse):

library(data.table)
library(collapse) # for %!in%

recursive_sample4 <- function(data, n) {
  sampled <- vector("list", uniqueN(data$exposed))
  data[
    ,.(
      unexposed = {
        x <- setdiff(unexposed, unlist(sampled))
        sampled[[.GRP]] <- x[sample.int(length(x), min(length(x), n))]
      }
    ), exposed
  ]
}

recursive_sample5 <- function(data, n) {
  sampled <- vector("list", uniqueN(data$exposed))
  data[
    ,.(
      unexposed = {
        x <- unique(unexposed[unexposed %!in% unlist(sampled)]) # unique(): %!in% keeps repeats
        sampled[[.GRP]] <- x[sample.int(length(x), min(length(x), n))]
      }
    ), exposed
  ]
}
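One behavioural difference between the two filters: setdiff() also removes repeated values, while a plain %in% (or %!in%) filter keeps them. If each unexposed ID may be drawn at most once per group, the filtered candidates need a unique() before sampling. The toy vectors below are illustrative only.

```r
x <- c(3L, 3L, 5L, 7L)  # candidate IDs for one group, with a repeated 3
used <- 5L              # IDs already taken by earlier groups

setdiff(x, used)         # unique values of x not in used
x[!x %in% used]          # keeps the repeated 3
unique(x[!x %in% used])  # matches setdiff()
```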

Timing (including recursive_sample3 by @minem):

data <- copy(df)

microbenchmark::microbenchmark(
  recursive_sample2 = recursive_sample2(df, 5L),
  recursive_sample3 = recursive_sample3(df, 5L),
  recursive_sample4 = recursive_sample4(setDT(df), 5L),
  recursive_sample5 = recursive_sample5(setDT(df), 5L),
  setup = {df <- copy(data)}
)
#> Unit: milliseconds
#>               expr      min        lq       mean    median        uq      max neval
#>  recursive_sample2 416.5425 427.38700 452.520780 436.58280 459.79430 614.6392   100
#>  recursive_sample3   4.5211   5.16330   6.765060   5.79820   6.95425  14.0693   100
#>  recursive_sample4   3.2038   3.57650   4.676284   4.41120   4.90855  11.6975   100
#>  recursive_sample5   2.2327   2.58255   3.384131   3.27405   3.93265   8.7091   100

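Note that recursive_sample3 can give erroneous results because of the behaviour of sample when its first argument has length 1: sample(x) with a single positive number x samples from 1:x rather than returning x. Here there are only 700 distinct unexposed IDs, so a result with more than 700 rows must reuse some of them: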

set.seed(123)
df <- tibble(exposed = rep(1:100, each = 100),
             unexposed = sample(1:700, 10000, replace = TRUE))
nrow(recursive_sample3(df, 10L))
#> [1] 704
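The pitfall in isolation, together with a small wrapper along the lines of the resample() helper shown in the Examples section of ?sample, which always samples from the elements of x:

```r
set.seed(42)
x <- 5L  # a single candidate ID

# sample() treats a single positive number as the population 1:x
draws <- replicate(100, sample(x, 1))

# Wrapper that samples positions, so the only possible result is x itself
resample <- function(x, ...) x[sample.int(length(x), ...)]
safe_draws <- replicate(100, resample(x, 1))

table(draws)
table(safe_draws)
```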
jblood94
  • You should be aware that the set operations are relatively slow, so you should be able to achieve some performance gain if you replace `setdiff` by some expression with `%in%`. Anyway, nice `data.table` approach, +1! – ThomasIsCoding Jun 05 '23 at 21:14
  • Thanks for the prod. That matches with my somewhat vague memory. I updated with `%!in%` from collapse for a bit of a speedup. – jblood94 Jun 05 '23 at 21:34