
I am running the stats::uniroot function on one million rows in a data.table. Here is a toy example:

library(data.table)
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 50000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

print(system.time(
dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
))

The above code takes close to 8 seconds for 50,000 rows.

Is there any faster alternative to the uniroot function which can reduce this time considerably?

Saurabh

3 Answers


160 seconds (1e6/5e4 * 8) doesn't sound so bad to me for a million rows (although maybe your real function is much slower than the froot you're using here?). This can be trivially parallelized, running separate chunks on separate cores (see e.g. answers to this question).
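For example, a chunked parallel version might look something like the sketch below (the chunk count and mc.cores value are arbitrary choices you would tune; mclapply forks, so on Windows you would use parLapply() from the same package instead):

library(parallel)

## Sketch: split the row indices into a few contiguous chunks, solve each
## chunk in a separate forked worker, and reassemble in the original order.
solve_one <- function(ui) {
  uniroot(froot, u = ui, a = a, b = b,
          interval = c(0.01, 10), extendInt = "yes")$root
}
idx_chunks <- split(seq_len(nrow(dt)), cut(seq_len(nrow(dt)), 4))
res <- mclapply(idx_chunks,
                function(idx) vapply(dt$u[idx], solve_one, numeric(1)),
                mc.cores = 4)
dt[, c_par := unlist(res)]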

How badly do you need extendInt? I can triple the speed by making a hacked version of the uniroot() function with only its core functionality and none of the argument-testing logic. However, your speed gain will be much less impressive if your target function is much slower than the example you've given here; if that's the case, you should focus on speeding up your target function. (I tried recoding your froot in C++ via Rcpp, but it doesn't really help in this case: the function is sufficiently trivial that the function-calling overhead takes most of the time.)
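For reference, a minimal Rcpp recoding of froot might look like the sketch below (assuming Rcpp is installed; this is roughly the kind of froot_cpp used in the benchmark further down):

library(Rcpp)

## Sketch of an Rcpp version of froot; as noted above it buys little here,
## because the R-level function-call overhead inside uniroot() dominates
## for such a trivial function.
cppFunction('
double froot_cpp(double x, double u, double a, double b) {
  return b * std::pow(x / b, a) - u;
}')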

I did this with only 5000 rows, for ease of benchmarking:

n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

Minimal function:

uu <- function(f, lower, upper, tol = 1e-8, maxiter = 1000L, ...) {
  ## stripped-down uniroot(): call the internal zeroin2 C routine directly,
  ## skipping the argument checking and the extendInt machinery
  f.lower <- f(lower, ...)
  f.upper <- f(upper, ...)
  val <- .External2(stats:::C_zeroin2, function(arg) f(arg, ...),
                    lower, upper, f.lower, f.upper, tol, as.integer(maxiter))
  return(val[1])  ## val also holds the iteration count and estimated precision
}

Check that we get the same results:

identical(uniroot(froot, u = 3.242, a=0.5, b=1, interval = c(0.01,100))$root,
          uu(froot, u = 3.242, a=0.5, b=1, lower = 0.01, upper = 100))
## TRUE

Benchmarking package; wrap evaluations in functions for compactness

library(rbenchmark)
f1 <- function() {
  dt[, c := uniroot(froot_cpp, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
}
f2 <- function() {
  dt[, c := uu(froot, u=u, a=a, b=b, lower = 0.01, upper = 100), by = u]
}
bb <- benchmark(f1(), f2(), 
    columns =c("test", "replications", "elapsed", "relative"))

Results:

  test replications elapsed relative
1 f1()          100  34.616    3.074
2 f2()          100  11.261    1.000
Ben Bolker
  • Thanks Ben, the minimal function improved the performance of my code by 20%, which is a huge improvement. – Saurabh Nov 06 '21 at 01:28
  • Great. In that case you will probably get a lot more mileage out of speeding up your target function though ... (and splitting/parallelizing will buy you about as much improvement as you have cores, if you have the memory to go with them). – Ben Bolker Nov 06 '21 at 01:29
  • I have 8 cores and ```data.table``` uses 4 cores by default. Increasing the cores from 4 to 8 did improve the performance by 2%, but that was not worth it as I want to keep a few cores available for other tasks. Using the function ```mclapply``` also does not improve the performance much. I will try to optimize the function. – Saurabh Nov 06 '21 at 01:33
  • The ```rootSolve::uniroot.all``` function improved performance over ```stats::uniroot``` by 15%, but the accuracy is compromised. Your minimal function is still faster and accurate. – Saurabh Nov 06 '21 at 01:41

Note that the inverse of the function shown can be computed explicitly as

f2 <- function(x) (b^a * x / b)^(1/a)
a <- 1/2
b <- 1
all.equal(f(.5), f2(.5))  # f defined below using uniroot
## [1] TRUE

However, assuming that in reality you have a more complex function, we can use Chebyshev approximation to get a close approximation to it. Note that a and b are constants in the question, so we also assume that below, i.e. f uses the constants a and b set in the global environment. With a 9th-degree polynomial, the code below runs nearly 100x faster than the code in the question on this benchmark and is within 1e-4 of the answer given by uniroot. Use a higher degree if you need even more accuracy.

library(data.table)
library(pracma)
set.seed(123)

cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

dt2 <- copy(dt)
f <- function(u) {
  uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root
}

library(microbenchmark)
microbenchmark(times = 10,
  orig = dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u],
  cheb = dt2[, c := chebApprox(u, Vectorize(f), min(u), max(u), 9)]
)
## Unit: milliseconds
##  expr      min       lq      mean    median       uq      max neval cld
##  orig 943.5323 948.9321 961.00361 958.91970 972.6308 982.0060    10   b
##  cheb   9.3752   9.7513  10.67386  10.02555  10.3411  16.9475    10  a 

max(abs(dt$c - dt2$c))
## [1] 8.081021e-05
G. Grothendieck
  • Thanks, Grothendieck. Unfortunately, variables ```a``` and ```b``` are not constants in my original function. Is there a way to pass them to the chebApprox function as you did in ```uniroot``` function? – Saurabh Nov 06 '21 at 15:57
  • 1
    you may be able to mess around with the `vectorize.args` argument of `Vectorize` – Ben Bolker Nov 06 '21 at 16:10
  • Since ```a``` and ```b``` are not constants, I tried the following but I am getting ```NaN``` values in variable ```c```. ```vec_f <- Vectorize(f, vectorize.args = c("u", "a", "b")) dt2[, c := chebApprox(x = u, fun = function(x) vec_f(u = x, a, b), a = min(u), b = max(u), n = 9), by = u]```. I must be doing something wrong, can you please point out? – Saurabh Nov 06 '21 at 20:13
  • If the problem cannot be reduced to one or more 1d problems then chebApprox is not applicable. For example, if there were a small number of a, b combinations then chebApprox could be applied to each one separately (see the sketch below). If there were special features of your function then there might also be other approaches, but we don't have any information on this. – G. Grothendieck Nov 06 '21 at 20:45
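A sketch of that per-group idea (hypothetical code, not from the answer above): if there are only a few distinct a, b combinations, fit a separate Chebyshev approximation within each group.

dt3 <- copy(dt)
## Sketch: build one Chebyshev approximation per (a, b) combination.
## Assumes each group has enough distinct u values to span an interval;
## accuracy should be re-checked for each group.
dt3[, c := {
  fab <- function(ui) uniroot(froot, u = ui, a = a[1], b = b[1],
                              interval = c(0.01, 10), extendInt = "yes")$root
  chebApprox(u, Vectorize(fab), min(u), max(u), 9)
}, by = .(a, b)]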

There are great answers to the exact question, but here are a couple of notes on general R practices.

Using by when order doesn't matter

In the OP, we are using by = u so that each row is run one at a time. This is inefficient! data.table will order u, determine groupings, and, since the values are random real numbers and therefore essentially all unique, end up with as many groups as there are rows.

Instead, we can use Map() or mapply() to iterate through the rows, which improves performance. Note that it's unclear whether a and b actually vary by row; if they truly are constant, we would likely want to take them out of the data.table and pass them as constants (a sketch of that variant follows the code below).

uniroot2 = function(...) uniroot(...)$root ## helper function
dt[, c2 := mapply(uniroot2, u, a,b,
                  MoreArgs = list (f = froot,
                                   interval = c(0.01, 10),
                                   extendInt = 'yes'))]
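
If a and b really are constant, that variant might look like this sketch: iterate over u only and pass a and b once through MoreArgs (a0 and b0 below are hypothetical scalars standing in for the constants).

a0 <- 1/2; b0 <- 1   ## hypothetical constants taken out of the data.table
dt[, c2b := mapply(uniroot2, u,
                   MoreArgs = list(f = froot, a = a0, b = b0,
                                   interval = c(0.01, 10),
                                   extendInt = 'yes'))]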

## for n = 5000

## # A tibble: 2 x 13
##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
##  1 OP           1.17s   1.17s     0.851     170KB     2.55     1     3      1.17s
##  2 no_by      857.2ms 857.2ms     1.17      214KB     3.50     1     3    857.2ms
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled. 

Note, once we have it set up in mapply, it is trivial to use future.apply::future_mapply() to parallelize our call. This is 2.5 times faster than the no_by example above on my laptop.

library(future.apply)
plan(multisession)
dt[, c3 := future_mapply(uniroot2, u, a,b,
                  MoreArgs = list (f = froot,
                                   interval = c(0.01, 10),
                                   extendInt = 'yes')
                  , future.globals = "cumhaz")] ## see next section for how we could remove this

Function calls take time

In your example, you define two functions as:

cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

When performance is an issue and the simplification is trivial, it may be worth doing.

froot2 = function(x, u, a, b) b * (x / b) ^ a - u

Over a million loops, the additional call to cumhaz() adds up:

x = 2.5; u = 1.5; a = 0.5; b = 1 
bench::mark(froot_rep = for (i in 1:1e6) {froot(x=x, u=u, a=a, b=b)},
            froot2_rep = for (i in 1:1e6) {froot2(x=x, u=u, a=a, b=b)})

## # A tibble: 2 x 13
##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
## 1 froot_rep    4.74s   4.74s     0.211    13.8KB     3.38     1    16      4.74s
## 2 froot2_rep   3.17s   3.17s     0.315    13.8KB     2.84     1     9      3.17s
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled. 

And uniroot further multiplies the number of calls, with a default maximum of 1,000 iterations per root! That means the extra cumhaz() call could cost somewhere between 1.5 s and 1,500 s during the optimization. And as @G. Grothendieck pointed out, sometimes we can solve directly and use vectorized methods instead of relying on uniroot or optimize.
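For the toy froot here, that direct route is just the closed-form inverse from @G. Grothendieck's answer applied to the whole column at once (a sketch; only possible when your real function can be inverted):

## Sketch: invert b * (t/b)^a = u analytically, so no root finder is needed
## and the whole column is computed in one vectorized expression.
dt[, c_direct := b * (u / b)^(1/a)]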

Cole
  • Thanks, Cole for the suggestion. It surely does not make sense to use grouping for row iteration. I will try your suggestion. – Saurabh Nov 06 '21 at 15:54
  • In my original function, variables ```a``` and ```b``` are not constants. I applied your suggestions to my code and here are the results. 1.) Merging the smaller functions together improved the performance by around 5%, thanks for that. 2.) Using ```mapply``` instead of grouping with ```by``` had no effect on the performance. 3.) Using parallelism decreased the performance by 5%. – Saurabh Nov 06 '21 at 17:47
  • Thanks for sharing. Somewhat surprising on findings 2 and 3. If you want to expand on your original post, it’d be interesting to see what the real implementation is. As for the `by`, you may want to do `by = seq_len(nrow(dt))`. This would not help much but is more efficient. – Cole Nov 06 '21 at 17:50
  • Using ```by = seq_len(nrow(dt))``` bumped up the performance by around 2% probably due to reduced overhead of ordering. Thanks! – Saurabh Nov 06 '21 at 18:00