
Let $A$ be an $m \times n$ matrix, not necessarily symmetric, and $D$ be a diagonal $n \times n$ matrix. In my problem, $m < n$ but both $m$ and $n$ have the potential to be quite large, so the computation of $ADA^T$ can take a few seconds as the dimensions get really high. The program I'm writing will carry out this computation numerous times for different variations of $A$ and $D$, so I want to find a way to speed it up, as that could save minutes over the course of the script.

Currently, I'm just doing `tcrossprod(A %*% D, A)`. But I feel like there should be a quicker way to compute it since $D$ is a diagonal matrix. Algebraically or computationally, does anyone see any good shortcuts here?

androsrj
  • Maybe [this comment](https://stackoverflow.com/questions/66576386/fast-matrix-multiplication-of-a-matrix-a-and-its-transpose-in-r-using-rcpp#comment117691934_66576386) and the next one are relevant. The 2nd comment links to a question whose [2nd answer](https://stackoverflow.com/a/63092829/8245406) might also be relevant. – Rui Barradas May 01 '23 at 19:06
  • A separate question, but would these be large *sparse* matrices (in which case you can gain enormously by using the `Matrix` package's sparse-matrix machinery)? – Ben Bolker May 01 '23 at 19:06
  • We can replace one of the matrix multiplies with an ordinary multiply of the diagonal. `A %*% (diag(D) * t(A))` – G. Grothendieck May 01 '23 at 19:22
  • @RuiBarradas, nice catch, but at a quick look it doesn't seem like it will be trivial to get the DSYRK2 call from BLAS embedded in R code ... ? – Ben Bolker May 01 '23 at 19:26
  • I think the `diag(D) * t(A)` may be what I'm looking for! I'll try it out but it definitely seems like it would be faster. Thank you all. – androsrj May 01 '23 at 19:51

2 Answers


Just to enumerate the different paths taken by the C code underlying tcrossprod ...

For dense `A`, of implicit class `matrix`:

If D is non-negative, then you can get BLAS routine DSYRK, which takes advantage of symmetry, with:

tcrossprod(A * rep(sqrt(diag(D, names = FALSE)), each = nrow(A)))

Otherwise, you are stuck with BLAS routine DGEMM:

tcrossprod(A * rep(diag(D, names = FALSE), each = nrow(A)), A)
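
As a quick sanity check, the two dense variants above can be compared against the explicit product on a small example. This is only a sketch: the sizes, the seed, and the names `d_pos`, `d_any`, `sym`, and `gen` are made up for illustration, and the diagonal is stored as a plain vector rather than as a matrix `D`.

```r
# Hypothetical small example; sizes and seed are arbitrary.
set.seed(42)
m <- 4L
n <- 7L
A <- matrix(rnorm(m * n), m, n)
d_pos <- runif(n)   # non-negative diagonal entries
d_any <- rnorm(n)   # diagonal entries of mixed sign

# Reference results using two explicit matrix multiplies.
ref_pos <- A %*% diag(d_pos) %*% t(A)
ref_any <- A %*% diag(d_any) %*% t(A)

# Non-negative case: scale column j of A by sqrt(d_j), then form one
# symmetric product with the one-argument tcrossprod().
sym <- tcrossprod(A * rep(sqrt(d_pos), each = nrow(A)))

# General case: scale column j of A by d_j, then form one general product.
gen <- tcrossprod(A * rep(d_any, each = nrow(A)), A)

stopifnot(
  isTRUE(all.equal(sym, ref_pos)),
  isTRUE(all.equal(gen, ref_any))
)
```

The `rep(..., each = nrow(A))` idiom works because R stores matrices column by column, so repeating each diagonal entry `nrow(A)` times lines it up with one full column of `A`.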

You can experiment with external BLAS if you know that A and D are finite, not containing IEEE special values Inf or NaN (including R's NA_real_). R will not use the external BLAS if it detects that either matrix factor is non-finite.

For sparse `A`, of formal class `dgCMatrix`:

If D is non-negative, then you can get CHOLMOD routine cholmod_aat, which takes advantage of symmetry, with:

A@x <- A@x * rep.int(sqrt(diag(D, names = FALSE)), diff(A@p))
tcrossprod(A)

Otherwise, you are stuck with CHOLMOD routines cholmod_transpose and cholmod_ssmult:

AD <- A
AD@x <- A@x * rep.int(diag(D, names = FALSE), diff(A@p))
tcrossprod(AD, A)
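
A small sparse check along the same lines, assuming the `Matrix` package is installed. The sizes, the density, and the names `d`, `nnz_per_col`, `res`, and `ref` are made up for illustration; `diff(A@p)` is used here to count the stored entries in each column.

```r
library(Matrix)

# Hypothetical small example; sizes, density, and seed are arbitrary.
set.seed(42)
m <- 5L
n <- 12L
A <- rsparsematrix(m, n, density = 0.3)  # random sparse matrix of class dgCMatrix
d <- rnorm(n)                            # diagonal entries, possibly negative

# A@p holds the 0-based column pointers (length ncol(A) + 1), so its
# first differences give the number of stored entries in each column.
nnz_per_col <- diff(A@p)

# Scale the stored entries of column j by d_j without densifying A.
AD <- A
AD@x <- A@x * rep.int(d, nnz_per_col)
res <- tcrossprod(AD, A)

# Dense reference for comparison; only sensible at this toy size.
ref <- as.matrix(A) %*% diag(d) %*% t(as.matrix(A))
stopifnot(isTRUE(all.equal(as.matrix(res), ref, check.attributes = FALSE)))
```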

In any case, because D is diagonal, you should only need one matrix multiply.
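
To spell out the algebra behind that remark (a sketch; the symbols $d_j$ for the $j$-th diagonal entry of $D$ and $a_j$ for the $j$-th column of $A$ are introduced here and do not appear in the question): multiplying by $D$ on the right only rescales columns, so it costs $O(mn)$ elementwise operations rather than a matrix product.

$$(AD)_{ij} = A_{ij}\, d_j, \qquad ADA^T = \sum_{j=1}^{n} d_j\, a_j a_j^T$$

When every $d_j \ge 0$, the product can also be written as a symmetric product of one scaled factor, which is what lets the one-argument form of `tcrossprod` be used:

$$ADA^T = \left(AD^{1/2}\right)\left(AD^{1/2}\right)^T$$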

Mikael Jagan

Here is a benchmark of several possible implementations, where the diagonal entries of `D` may be negative (hence the use of `as.complex` in `f4` and `f5`):

set.seed(1)
m <- 100
n <- 500
A <- matrix(runif(m * n), m, n)
D <- diag(rnorm(n))


library(microbenchmark)

microbenchmark(
  f0 = A %*% D %*% t(A),
  f1 = A %*% (diag(D) * t(A)),
  f2 = tcrossprod(A %*% D, A),
  f3 = tcrossprod(A, tcrossprod(A, D)),
  f4 = Re(tcrossprod(A %*% `dim<-`(sqrt(as.complex(D)), dim(D)))),
  f5 = Re(tcrossprod(A * sqrt(as.complex(diag(D)))[col(A)])),
  check = "equal"
)

and we will see

Unit: milliseconds
 expr     min       lq      mean   median       uq      max neval
   f0 14.2376 15.10425 15.411022 15.30060 15.48965  18.6084   100
   f1  2.5351  2.69190  2.837047  2.73960  2.80545   6.5885   100
   f2 14.1424 14.80185 15.114680 14.95035 15.10380  20.3685   100
   f3 13.5691 14.38270 15.200161 14.82995 15.45375  23.1657   100
   f4 32.7007 34.29920 37.068124 34.99865 36.99130 118.8519   100
   f5  5.4284  5.61460  5.989452  5.83605  5.92440  10.4077   100
ThomasIsCoding