I have been trying to speed up a toy GEMM implementation. I work with blocks of 32x32 doubles, for which I need an optimized matrix-multiplication (MM) kernel. I have access to AVX2 and FMA.
I have two versions of the kernel (in ASM; I apologize for the crude formatting), shown below: one uses plain AVX2 multiplies and adds, the other uses FMA.
Without going into micro-benchmarks, I would like to develop a theoretical understanding of why the AVX2 implementation is 1.11x faster than the FMA version, and possibly learn how to improve both versions.
The code below is for a 3000x3000 MM of doubles, and the kernels implement the classical, naive MM with the deepest loop interchanged. I'm using a Ryzen 3700X (Zen 2) as my development CPU.
I have not tried unrolling aggressively, for fear that the CPU might run out of physical registers.
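For reference, this is roughly the scalar loop nest the kernels implement. It is only a sketch under my own naming: c, a, b, ld and tile32_naive are hypothetical names, and ld = 3000 is the leading dimension I infer from the 0x5dc0 (= 3000*8 byte) strides in the listings below:

#include <stddef.h>

/* Naive 32x32 tile update with the deepest loop interchanged (i-k-j order),
   so the innermost j loop walks contiguously through c and b.
   c, a, b point at the top-left element of the current 32x32 tile;
   ld is the leading dimension in doubles (3000 here). */
static void tile32_naive(double *c, const double *a, const double *b, ptrdiff_t ld)
{
    for (int i = 0; i < 32; ++i)
        for (int k = 0; k < 32; ++k)
            for (int j = 0; j < 32; ++j)
                c[i*ld + j] += a[i*ld + k] * b[k*ld + j];
}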
AVX2 32x32 MM kernel:
Block 82:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 83:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
vaddpd ymm2, ymm10, ymm2
vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
vaddpd ymm3, ymm11, ymm3
vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
vaddpd ymm0, ymm9, ymm0
vaddpd ymm1, ymm12, ymm1
vaddpd ymm4, ymm10, ymm4
vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
vmulpd ymm8, ymm8, ymmword ptr [rax]
vaddpd ymm5, ymm11, ymm5
add rax, 0x5dc0
vaddpd ymm6, ymm10, ymm6
vaddpd ymm7, ymm8, ymm7
cmp r13, 0x20
jnz 0x140004530 <Block 83>
Block 84:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400044d0 <Block 82>
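In intrinsics form, the AVX2 kernel above corresponds roughly to the following sketch (again my reconstruction, with the same hypothetical names as before): one ymm broadcast of a[i*ld + k], eight ymm accumulators covering a 32-double row of c, and a separate vmulpd/vaddpd pair per accumulator:

#include <immintrin.h>
#include <stddef.h>

/* 32x32 tile update, multiply + add version (mirrors Blocks 82-84 above). */
static void tile32_avx2(double *c, const double *a, const double *b, ptrdiff_t ld)
{
    for (int i = 0; i < 32; ++i) {
        __m256d acc[8];
        for (int v = 0; v < 8; ++v)               /* load one 32-double row of c */
            acc[v] = _mm256_loadu_pd(c + i*ld + 4*v);
        for (int k = 0; k < 32; ++k) {
            __m256d aik = _mm256_broadcast_sd(a + i*ld + k);
            for (int v = 0; v < 8; ++v) {         /* separate multiply and add */
                __m256d p = _mm256_mul_pd(aik, _mm256_loadu_pd(b + k*ld + 4*v));
                acc[v] = _mm256_add_pd(p, acc[v]);
            }
        }
        for (int v = 0; v < 8; ++v)               /* store the row of c back */
            _mm256_storeu_pd(c + i*ld + 4*v, acc[v]);
    }
}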
AVX2/FMA 32x32 MM kernel:
Block 80:
imul r12, r15, 0xbb8
mov rax, r11
mov r13d, 0x0
vmovupd ymm0, ymmword ptr [rdi+r12*8]
vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
lea r14, ptr [r12+0x4]
nop dword ptr [rax+rax*1], eax
Block 81:
vbroadcastsd ymm8, qword ptr [rcx+r13*8]
inc r13
vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
add rax, 0x5dc0
cmp r13, 0x20
jnz 0x140004450 <Block 81>
Block 82:
inc r15
add rcx, 0x5dc0
vmovupd ymmword ptr [rdi+r12*8], ymm0
vmovupd ymmword ptr [rdi+r14*8], ymm1
vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
cmp r15, 0x20
jnz 0x1400043f0 <Block 80>
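And the FMA kernel corresponds roughly to the same structure, with each multiply/add pair fused into a single vfmadd231pd per accumulator (same hypothetical names and assumptions as in the sketches above):

#include <immintrin.h>
#include <stddef.h>

/* 32x32 tile update, FMA version (mirrors Blocks 80-82 of the FMA kernel). */
static void tile32_fma(double *c, const double *a, const double *b, ptrdiff_t ld)
{
    for (int i = 0; i < 32; ++i) {
        __m256d acc[8];
        for (int v = 0; v < 8; ++v)               /* load one 32-double row of c */
            acc[v] = _mm256_loadu_pd(c + i*ld + 4*v);
        for (int k = 0; k < 32; ++k) {
            __m256d aik = _mm256_broadcast_sd(a + i*ld + k);
            for (int v = 0; v < 8; ++v)           /* fused multiply-add into each accumulator */
                acc[v] = _mm256_fmadd_pd(aik, _mm256_loadu_pd(b + k*ld + 4*v), acc[v]);
        }
        for (int v = 0; v < 8; ++v)               /* store the row of c back */
            _mm256_storeu_pd(c + i*ld + 4*v, acc[v]);
    }
}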