
I have to do the following recursive row-by-row operation to obtain z:

myfun = function (xb, a, b) {
  z = NULL
  for (t in 1:length(xb)) {
    if (t >= 2) { a[t] = b[t-1] + xb[t] }
    z[t] = rnorm(1, mean = a[t])
    b[t] = a[t] + z[t]
  }
  return(z)
}

library(data.table)

set.seed(1)

n_smpl = 1e6 
ni = 5

id = rep(1:n_smpl, each = ni)

smpl = data.table(id)
smpl[, time := 1:.N, by = id]

a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

smpl[, z := myfun(xb, a, b), by = id]

I would like to obtain a result like this (output shown for n_smpl = 100):

      id time a b  xb            z
  1:   1    1 1 1   1    0.3735462
  2:   1    2 1 1   2    2.7470924
  3:   1    3 1 1   3    8.4941848
  4:   1    4 1 1   4   20.9883695
  5:   1    5 1 1   5   46.9767390
 ---                              
496: 100    1 1 1 100    0.3735462
497: 100    2 1 1 200  200.7470924
498: 100    3 1 1 300  701.4941848
499: 100    4 1 1 400 1802.9883695
500: 100    5 1 1 500 4105.9767390

This works, but it is slow:

system.time(smpl[, z := myfun(xb, a, b), by = id])
   user  system elapsed 
 33.646   0.994  34.473

I need to make it faster, given the size of my actual data (over 2 million observations). I guess that something like do.call(myfun, .SD), .SDcols = c('xb', 'a', 'b') with by = .(id, time) would be much faster, because it would avoid the for loop inside myfun. However, I'm not sure how to update b and its lag (probably using shift) when running this row-by-row operation in data.table. Any suggestions?

jayc
  • I get `Error in rnorm(1, mean = a[t]) : object 'a' not found`. Can you ensure the code works in a fresh R session please. – Matt Dowle Jan 27 '17 at 03:27
  • Sorry for the confusion. I just fixed the code and it now works out. – jayc Jan 27 '17 at 03:30
  • Have you pasted it into a fresh R session? I still get the same error. – Matt Dowle Jan 27 '17 at 03:33
  • My mistake again - works now. Hope you understand - this is my first time to ask a question on this board. I truly appreciate your help! – jayc Jan 27 '17 at 03:37
  • No worries. Welcome to S.O. Congrats for asking the 5,000th data.table question! – Matt Dowle Jan 27 '17 at 03:43
  • Since you have a working solution but you need a different solution that works faster, you need to please provide a means to generate a large amount of random data so we can reproduce something that takes a long time. If that's already the case then make it clearer; e.g. should we set `n_smpl` to `1e6` and how long does that take for you? – Matt Dowle Jan 27 '17 at 03:51
  • The `set.seed(1)` inside the for loop looks strange. Is that necessary? Normally `set.seed()` is called once at the start of the reproducible example, not many many times. – Matt Dowle Jan 27 '17 at 03:55
  • Just tried to make the example reproducible - it is not necessary. – jayc Jan 27 '17 at 04:05

1 Answer


Great question!

Starting from a fresh R session, here is the demo data with 5 million rows, your function from the question, and the timing on my laptop, with some comments inline.

require(data.table)   # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

myfun = function (xb, a, b) {

  z = NULL
  # initializes a new length-0 variable

  for (t in 1:length(xb)) {

      if (t >= 2) { a[t] = b[t-1] + xb[t] }
      # if() on every iteration. t==1 could be done before loop

      z[t] = rnorm(1, mean = a[t])
      # z vector is grown by 1 item, each time

      b[t] = a[t] + z[t]
      # stores every b[t] when only b[t-1] is
      # needed on the next iteration
  }
  return(z)
}

set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
   user  system elapsed 
 19.216   0.004  19.212

smpl
              id time a b      xb            z
      1:       1    1 1 1       1 3.735462e-01
      2:       1    2 1 1       2 3.557190e+00
      3:       1    3 1 1       3 9.095107e+00
      4:       1    4 1 1       4 2.462112e+01
      5:       1    5 1 1       5 5.297647e+01
     ---                                      
4999996: 1000000    1 1 1 1000000 1.618913e+00
4999997: 1000000    2 1 1 2000000 2.000000e+06
4999998: 1000000    3 1 1 3000000 7.000003e+06
4999999: 1000000    4 1 1 4000000 1.800001e+07
5000000: 1000000    5 1 1 5000000 4.100001e+07

So 19.2s is the time to beat. In all these timings, I've run the command 3 times locally to make sure it's a stable timing. The timing variance is insignificant in this task so I'll just report one timing to keep the answer quicker to read.

Addressing the inline comments on myfun() above:

myfun2 = function (xb, a, b) {

  z = numeric(length(xb))
  # allocate size up front rather than growing

  z[1] = rnorm(1, mean=a[1])
  prevb = a[1]+z[1]
  t = 2L
  while(t<=length(xb)) {
    at = prevb + xb[t]
    z[t] = rnorm(1, mean=at)
    prevb = at + z[t]
    t = t+1L
  }
  return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
   user  system elapsed 
 13.212   0.036  13.245 
smpl[,identical(z,z2)]
[1] TRUE

That was quite good (19.2s down to 13.2s), but it's still a loop at R level. At first glance it can't be vectorized, because each rnorm() call depends on the previous value. In fact, it probably can be vectorized by using the property that m + sd*rnorm(1, mean=0, sd=1) is equivalent to rnorm(1, mean=m, sd=sd), and calling the vectorized rnorm(n=5e6) once rather than rnorm(n=1) 5e6 times. But there'd probably be a cumsum() involved to deal with the groups. So let's not go there, as that would probably make the code harder to read and would be specific to this precise problem.
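To make that idea concrete: writing each draw as its mean plus a standard-normal draw ε_t (notation introduced here, not used in the code), the recursion inside one group unrolls as below. This is a sketch of where the cumsum() would come from, not something the code in this answer uses.

```latex
\begin{aligned}
z_t &= a_t + \varepsilon_t, \qquad \varepsilon_t \sim N(0,1) \text{ i.i.d.} \\
b_t &= a_t + z_t = 2a_t + \varepsilon_t \\
a_t &= b_{t-1} + xb_t = 2a_{t-1} + \varepsilon_{t-1} + xb_t, \qquad t \ge 2 \\
\frac{a_t}{2^{t-1}} &= a_1 + \sum_{s=2}^{t} \frac{\varepsilon_{s-1} + xb_s}{2^{s-1}}
\end{aligned}
```

So, once all the ε are drawn in one rnorm() call, a_t is 2^{t-1} times a within-group cumulative sum of rescaled terms, and z_t = a_t + ε_t. The doubling at every step is exactly the kind of problem-specific detail that makes this route less general than a plain loop.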

So let's try Rcpp, which looks very similar to the style you wrote and is more widely applicable:

require(Rcpp)   # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
  NumericVector z = NumericVector(xb.length());
  z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
  double prevb = a[0]+z[0];
  int t = 1;
  while (t<xb.length()) {
    double at = prevb + xb[t];
    z[t] = R::rnorm(at, 1);
    prevb = at + z[t];
    t++;
  }
  return z;
}')

set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
   user  system elapsed 
  1.800   0.020   1.819 
smpl[,identical(z,z3)]
[1] TRUE

Much better: 19.2s down to 1.8s. But every call to the function runs its first line (NumericVector()), which allocates a new vector as long as the number of rows in the group. That vector is filled in and returned, then copied by := into the right place in the final column for that group, only to be released. Allocating and managing all those 1 million small temporary vectors (one per group) is a bit convoluted.

Why don't we do the whole column in one go? You've already written it in a for-loop style and there's nothing wrong with that. Let's tweak the C++ function to accept the id column too, and add an if to detect when it reaches a new group.

cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {

  // ** id must be pre-grouped, such as via setkey(DT,id) **

  NumericVector z = NumericVector(id.length());
  int previd = id[0]-1;  // initialize to anything different from id[0]
  double prevb = 0;      // declared outside the loop so it persists across rows
  for (int i=0; i<id.length(); i++) {
    if (id[i]!=previd) {
      // first row of new group
      z[i] = R::rnorm(a[i], 1);
      prevb = a[i]+z[i];
      previd = id[i];
    } else {
      // 2nd row of group onwards
      double at = prevb + xb[i];
      z[i] = R::rnorm(at, 1);
      prevb = at + z[i];
    }
  }
  return z;
}')

system.time(setkey(smpl,id))  # ensure grouped by id
   user  system elapsed
  0.028   0.004   0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
   user  system elapsed 
  0.232   0.004   0.237 
smpl[,identical(z,z4)]
[1] TRUE

That's better: 19.2s down to 0.27s (0.24s for myfun4 plus 0.03s for setkey).

Matt Dowle