I tried looping over each item, but is that what pytorch does?
The short answer is yes, loops are used, but it's more complicated than you probably think. If input is a 2D tensor (a matrix), then the output of a linear operation is computed as input @ weight.T + bias using an external BLAS library's GEMM operation. Otherwise it uses torch.matmul(input, weight.T) + bias, which uses broadcast semantics to compute a batched version of the operation. Broadcasting is a semantic specification, not an implementation, so how the broadcasting is actually performed is backend-dependent. Ultimately, some form of looping combined with parallel processing is used in most of these implementations.
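To see the two expressions from the short answer side by side, here is a small sketch (the shapes and variable names are my own, not taken from the PyTorch source):

import torch
import torch.nn.functional as F

x2d = torch.randn(4, 8)          # 2D input: [batch, in_features]
x3d = torch.randn(2, 4, 8)       # higher-dimensional input: [..., in_features]
weight = torch.randn(3, 8)       # nn.Linear layout: [out_features, in_features]
bias = torch.randn(3)

# 2D case: the fused GEMM-style expression.
out_2d = x2d @ weight.T + bias
print(torch.allclose(out_2d, F.linear(x2d, weight, bias), atol=1e-6))   # True

# Higher-dimensional case: matmul broadcasts over the leading dimensions.
out_3d = torch.matmul(x3d, weight.T) + bias
print(torch.allclose(out_3d, F.linear(x3d, weight, bias), atol=1e-6))   # True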
To go a little deeper, let's take a look at the PyTorch implementation of the linear layer. This quickly leads down some rabbit holes, since PyTorch uses different backend libraries to perform linear algebra operations efficiently on the available hardware (libraries like oneAPI, Intel MKL, or MAGMA), but perhaps understanding some of the details can help.
Starting at the C++ entrypoint to nn.functional.linear:
Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, bias);
  }
  if (input.dim() == 2 && bias.defined()) {
    // Fused op is marginally faster.
    return at::addmm(bias, input, weight.t());
  }
  auto output = at::matmul(input, weight.t());
  if (bias.defined()) {
    output.add_(bias);
  }
  return output;
}
There are three cases here.
input.is_mkldnn(). This condition is met if the input tensor is in the MKL-DNN format (Tensor.to_mkldnn) and makes PyTorch use the at::mkldnn_linear function, which in turn calls into ideep, which in turn calls the oneDNN library (previously known as Intel MKL-DNN), which ultimately selects a specific general matrix-matrix multiplication (GEMM) routine depending on the platform and data types. The simplest implementation is the reference implementation, and from it we can see that they use a parallel-for loop (note that the anonymous function passed to it contains a quadruply nested for-loop). In practice the reference implementation probably isn't used; instead, you would likely be calling the x86-optimized version, compiled with the Xbyak JIT assembler to produce highly optimized code. I'm not going to pretend to follow all the details of the optimized code, but efficient GEMM is a heavily studied topic that I only have passing knowledge of.
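To give a rough picture of that loop structure, here is a Python sketch of a naive, reference-style linear kernel. This is only an illustration of the idea, not the oneDNN code: the real reference kernel is C++, is parallelized (roughly over output elements), and also loops over additional dimensions for convolution-style inputs, which is where the fourth nested loop comes from. The names below are mine.

import torch
import torch.nn.functional as F

def reference_style_linear(x, w, b):
    # x: [MB, IC], w: [OC, IC], b: [OC] -- following the inner-product naming
    MB, IC = x.shape
    OC = w.shape[0]
    out = torch.zeros(MB, OC)
    # oneDNN would run the (mb, oc) iterations inside a parallel-for;
    # here we just loop sequentially for clarity.
    for mb in range(MB):
        for oc in range(OC):
            acc = 0.0
            for ic in range(IC):      # innermost reduction loop
                acc += x[mb, ic].item() * w[oc, ic].item()
            out[mb, oc] = acc + b[oc]
    return out

x, w, b = torch.randn(2, 5), torch.randn(3, 5), torch.randn(3)
print(torch.allclose(reference_style_linear(x, w, b), F.linear(x, w, b), atol=1e-5))  # True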
input.dim() == 2 && bias.defined(). This condition means that input is a 2D tensor (shape [B, M]) and bias is defined. In this case PyTorch uses the addmm function, which efficiently computes the output as input @ weight.T + bias, where @ is matrix multiplication. There are multiple implementations of addmm registered in PyTorch depending on what types of tensors are being used. The dense-CPU-specific version is here, which eventually calls an external BLAS library's GEMM subroutine. The backend used is likely Intel MKL, but you can check using print(torch.__config__.parallel_info()). Whichever BLAS implementation is being used, it's certainly a highly optimized implementation of matrix multiplication, similar to the oneDNN implementation, probably using multi-threading and optimized compilation.
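For reference, you can reproduce what this branch computes directly from Python and inspect your own build's configuration (the output varies by install; the shapes here are just an example):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)            # [B, M]
weight = torch.randn(3, 8)       # [N, M]
bias = torch.randn(3)            # [N]

# addmm(bias, input, weight.t()) fuses the matrix multiply and the bias add.
out = torch.addmm(bias, x, weight.t())
print(torch.allclose(out, F.linear(x, weight, bias), atol=1e-6))   # True

# Threading / MKL information for your build; torch.__config__.show()
# additionally lists the build's compile-time options.
print(torch.__config__.parallel_info())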
If neither of the previous two conditions is met, then PyTorch uses the torch.matmul function, which performs a broadcasted version of input @ weight.T, where input has shape [..., M] and the result has shape [..., N]. Similar to addmm, there are multiple implementations of this function depending on the tensor types, but an external library will ultimately be used that relies on parallelization and optimized matrix-multiplication subroutines. After the broadcasted matrix multiplication, a broadcasted add_ operation adds the bias term (if bias is defined).
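One way to see what the broadcast semantics mean here is that the broadcasted matmul gives the same result as explicitly looping over the leading dimensions. A small sketch with example shapes of my own choosing:

import torch

x = torch.randn(2, 4, 8)        # input of shape [..., M] with leading dims (2, 4)
weight = torch.randn(3, 8)      # [N, M]
bias = torch.randn(3)           # [N]

# Broadcasted path: weight.T is effectively reused across the leading dimensions.
out = torch.matmul(x, weight.T)
out.add_(bias)                  # broadcasted in-place bias add, as in the C++ code above

# Equivalent explicit loop over the leading (batch) dimensions.
out_loop = torch.stack([
    torch.stack([x[i, j] @ weight.T + bias for j in range(x.shape[1])])
    for i in range(x.shape[0])
])
print(torch.allclose(out, out_loop, atol=1e-6))   # True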