1. I want to multiply a batch of matrices with a batch of matrices of the same length, pairwise
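```python
import tensorflow as tf  # TF 1.x API, as used throughout this answer

batch_size, n, m, p = 4, 2, 3, 5  # example sizes, reused in the snippets below
M = tf.random_normal((batch_size, n, m))
N = tf.random_normal((batch_size, m, p))
# Python >= 3.5: the @ operator
MN = M @ N
# or, equivalently, the explicit call
MN = tf.matmul(M, N)
# MN has shape (batch_size, n, p)
```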
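To make the pairwise semantics concrete, here is a minimal sketch in NumPy rather than TensorFlow, so that it runs without a TF install. It assumes `np.matmul` (the `@` operator) behaves like `tf.matmul` for two 3-D inputs with the same batch size, which it does for these shapes. Each slice of the result should be the ordinary product of the corresponding pair of matrices; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n, m, p = 4, 2, 3, 5  # arbitrary example sizes

M = rng.standard_normal((batch_size, n, m))
N = rng.standard_normal((batch_size, m, p))

# Batched product: one (n, m) @ (m, p) product per batch element.
MN = M @ N

# Reference: loop over the batch and multiply each pair of 2-D matrices.
MN_loop = np.stack([M[b] @ N[b] for b in range(batch_size)])

print(MN.shape)                  # (4, 2, 5)
print(np.allclose(MN, MN_loop))  # True
```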
2. I want to multiply a batch of matrices with a batch of vectors of the same length, pairwise
We fall back to case 1 by adding a trailing dimension to `v`, which turns it into a batch of column vectors, and then removing that dimension from the result.
```python
M = tf.random_normal((batch_size, n, m))
v = tf.random_normal((batch_size, m))
# v[..., None] has shape (batch_size, m, 1); [..., 0] drops the trailing axis again
Mv = (M @ v[..., None])[..., 0]
# Mv has shape (batch_size, n)
```
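A quick NumPy check of the add-then-remove-a-dimension trick, used here as a stand-in for TensorFlow (the indexing `v[..., None]` and `[..., 0]` means the same in both). The result is compared with an explicit loop of matrix-vector products; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n, m = 4, 2, 3  # arbitrary example sizes

M = rng.standard_normal((batch_size, n, m))
v = rng.standard_normal((batch_size, m))

# v[..., None]: (batch_size, m, 1), a batch of column vectors.
# M @ v[..., None]: (batch_size, n, 1); [..., 0] removes the last axis.
Mv = (M @ v[..., None])[..., 0]

# Reference: one matrix-vector product per batch element.
Mv_loop = np.stack([M[b] @ v[b] for b in range(batch_size)])

print(Mv.shape)                  # (4, 2)
print(np.allclose(Mv, Mv_loop))  # True
```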
3. I want to multiply a single matrix with a batch of matrices
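In this case, we cannot simply add a batch dimension of 1 to the single matrix, because `tf.matmul` does not broadcast in the batch dimension. (This holds for the TensorFlow 1.x releases this answer targets; more recent releases do broadcast batch dimensions in `tf.matmul`, so check the documentation of your version.)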
3.1. The single matrix is on the right side
In that case, we can treat the batch of matrices as a single tall matrix with a simple reshape, multiply it by `N`, and reshape the result back.
```python
M = tf.random_normal((batch_size, n, m))
N = tf.random_normal((m, p))
MN = tf.reshape(tf.reshape(M, [-1, m]) @ N, [-1, n, p])
# MN has shape (batch_size, n, p)
```
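Why the reshape is safe: in row-major order, reshaping `(batch_size, n, m)` to `(batch_size * n, m)` stacks the matrices on top of each other, and each row of a matrix product depends only on the matching row of the left operand. The sketch below checks this in NumPy (standing in for TensorFlow; `reshape` with `-1` behaves the same way in both) against a per-matrix loop, with arbitrary sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n, m, p = 4, 2, 3, 5  # arbitrary example sizes

M = rng.standard_normal((batch_size, n, m))  # batch of matrices
N = rng.standard_normal((m, p))              # single matrix, on the right

# (batch_size, n, m) -> (batch_size * n, m): one tall matrix.
tall = M.reshape(-1, m)
# One 2-D product, then split the rows back into groups of n.
MN = (tall @ N).reshape(-1, n, p)

# Reference: multiply each matrix of the batch by N separately.
MN_loop = np.stack([M[b] @ N for b in range(batch_size)])

print(MN.shape)                  # (4, 2, 5)
print(np.allclose(MN, MN_loop))  # True
```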
3.2. The single matrix is on the left side
This case is more complicated. We can fall back to case 3.1 by transposing the matrices: since (MN)ᵀ = NᵀMᵀ, the single matrix ends up on the right side.
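```python
M = tf.random_normal((n, m))              # single matrix, now on the left
N = tf.random_normal((batch_size, m, p))  # batch of matrices
MT = tf.matrix_transpose(M)
NT = tf.matrix_transpose(N)
NTMT = tf.reshape(tf.reshape(NT, [-1, m]) @ MT, [-1, p, n])
MN = tf.matrix_transpose(NTMT)
# MN has shape (batch_size, n, p)
```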
However, transposition can be a costly operation, and here it is done twice on an entire batch of matrices. It may be better to simply duplicate `M` to match the batch dimension:
```python
MN = tf.tile(M[None], [batch_size, 1, 1]) @ N
```
Profiling will tell which option works better for a given problem/hardware combination.
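Both options for case 3.2 can be checked against a per-matrix loop. This is a NumPy sketch standing in for TensorFlow: `np.swapaxes(x, -1, -2)` plays the role of `tf.matrix_transpose` and `np.tile` that of `tf.tile`. Note that NumPy's own `@` would broadcast `M` directly; the point here is only to confirm that the two TensorFlow-style formulas compute the same thing. Sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n, m, p = 4, 2, 3, 5  # arbitrary example sizes

M = rng.standard_normal((n, m))              # single matrix, on the left
N = rng.standard_normal((batch_size, m, p))  # batch of matrices

# Option 1: (M N)^T = N^T M^T. Transpose, apply the case 3.1 reshape trick,
# then transpose back.
MT = M.T                                     # (m, n)
NT = np.swapaxes(N, -1, -2)                  # (batch_size, p, m)
NTMT = (NT.reshape(-1, m) @ MT).reshape(-1, p, n)
MN_transpose = np.swapaxes(NTMT, -1, -2)     # (batch_size, n, p)

# Option 2: copy M once per batch element, then do a plain batched product.
MN_tile = np.tile(M[None], (batch_size, 1, 1)) @ N

# Reference: multiply M by each matrix of the batch separately.
MN_loop = np.stack([M @ N[b] for b in range(batch_size)])

print(np.allclose(MN_transpose, MN_loop), np.allclose(MN_tile, MN_loop))  # True True
```

Option 1 moves the whole batch through two transposes, while option 2 materializes `batch_size` copies of `M`; which cost dominates depends on the sizes and the hardware.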
4. I want to multiply a single matrix with a batch of vectors
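This looks similar to case 3.2, since the single matrix is on the left, but it is actually simpler: a batch of vectors stored as a `(batch_size, m)` tensor is already a stack of transposed vectors, so transposing it is essentially a no-op and only the single matrix `M` needs to be transposed. We end up with: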
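```python
M = tf.random_normal((n, m))
v = tf.random_normal((batch_size, m))
MT = tf.matrix_transpose(M)
Mv = v @ MT
# Mv has shape (batch_size, n); tf.matmul(v, M, transpose_b=True) gives the same result
```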
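Row `b` of `v @ Mᵀ` is `(M v_b)ᵀ`, so a single 2-D product handles the whole batch. A minimal NumPy check (NumPy standing in for TensorFlow, arbitrary sizes) against an explicit loop:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n, m = 4, 2, 3  # arbitrary example sizes

M = rng.standard_normal((n, m))           # single matrix
v = rng.standard_normal((batch_size, m))  # one row per vector

# (batch_size, m) @ (m, n) -> (batch_size, n)
Mv = v @ M.T

# Reference: one matrix-vector product per batch element.
Mv_loop = np.stack([M @ v[b] for b in range(batch_size)])

print(Mv.shape, np.allclose(Mv, Mv_loop))  # (4, 2) True
```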
What about `einsum`?
All of the previous multiplications could have been written with the `tf.einsum` Swiss army knife. For example, case 3.2 can be written simply as
```python
MN = tf.einsum('nm,bmp->bnp', M, N)
```
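For reference, here is a sketch of the `einsum` subscripts for every case above, each checked against the explicit formula. It uses `np.einsum` so it runs without TensorFlow, under the assumption that `tf.einsum` accepts the same subscript strings for these plain contractions (it does for this standard notation). The variable names `A`, `B`, `S`, `T` are introduced here only to keep all cases in one script.

```python
import numpy as np

rng = np.random.default_rng(0)
b, n, m, p = 4, 2, 3, 5  # arbitrary example sizes

A = rng.standard_normal((b, n, m))  # batch of matrices
B = rng.standard_normal((b, m, p))  # batch of matrices
S = rng.standard_normal((n, m))     # single matrix (left side)
T = rng.standard_normal((m, p))     # single matrix (right side)
v = rng.standard_normal((b, m))     # batch of vectors

# Case 1: batch of matrices x batch of matrices
assert np.allclose(np.einsum('bnm,bmp->bnp', A, B), A @ B)
# Case 2: batch of matrices x batch of vectors
assert np.allclose(np.einsum('bnm,bm->bn', A, v), (A @ v[..., None])[..., 0])
# Case 3.1: batch of matrices x single matrix on the right
assert np.allclose(np.einsum('bnm,mp->bnp', A, T),
                   (A.reshape(-1, m) @ T).reshape(-1, n, p))
# Case 3.2: single matrix on the left x batch of matrices
assert np.allclose(np.einsum('nm,bmp->bnp', S, B), np.tile(S[None], (b, 1, 1)) @ B)
# Case 4: single matrix x batch of vectors
assert np.allclose(np.einsum('nm,bm->bn', S, v), v @ S.T)

print("all einsum forms agree")
```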
However, note that `einsum` ultimately relies on `transpose` and `matmul` to do the computation.
So even though `einsum` is a very convenient way to write matrix multiplications, it hides the complexity of the operations underneath: for example, it is not straightforward to guess how many times an `einsum` expression will transpose your data, and therefore how costly the operation will be. It may also hide the fact that there can be several alternatives for the same operation (see case 3.2), and it will not necessarily choose the better one.
For this reason, I would personally use explicit formulas like those above, which better convey their respective complexity. That said, if you know what you are doing and like the simplicity of the `einsum` syntax, then by all means go for it.