It sometimes helps to write expressions as matrix-vector products. Assuming you already know sₖ₊₈, you can calculate sₖ to sₖ₊₇ from aₖ to aₖ₊₇ using
[ µ µ² µ³ µ⁴ µ⁵ µ⁶ µ⁷ µ⁸] [aₖ₊₀ ]
[ 0 µ µ² µ³ µ⁴ µ⁵ µ⁶ µ⁷] [aₖ₊₁ ]
[ 0 0 µ µ² µ³ µ⁴ µ⁵ µ⁶] [aₖ₊₂ ]
[ 0 0 0 µ µ² µ³ µ⁴ µ⁵] [aₖ₊₃ ]
[ 0 0 0 0 µ µ² µ³ µ⁴] * [aₖ₊₄ ]
[ 0 0 0 0 0 µ µ² µ³] [aₖ₊₅ ]
[ 0 0 0 0 0 0 µ µ²] [aₖ₊₆ ]
[ 0 0 0 0 0 0 0 µ ] [aₖ₊₇+sₖ₊₈]
Since sₖ₊₈ will likely still have some latency when this is calculated, it makes sense to move it out of the product. It can then be applied with one broadcast and one fused multiply-add:
[ µ µ² µ³ µ⁴ µ⁵ µ⁶ µ⁷ µ⁸] [aₖ₊₀] [ µ⁸]
[ 0 µ µ² µ³ µ⁴ µ⁵ µ⁶ µ⁷] [aₖ₊₁] [ µ⁷]
[ 0 0 µ µ² µ³ µ⁴ µ⁵ µ⁶] [aₖ₊₂] [ µ⁶]
[ 0 0 0 µ µ² µ³ µ⁴ µ⁵] [aₖ₊₃] [ µ⁵]
[ 0 0 0 0 µ µ² µ³ µ⁴] * [aₖ₊₄] + [ µ⁴] * sₖ₊₈
[ 0 0 0 0 0 µ µ² µ³] [aₖ₊₅] [ µ³]
[ 0 0 0 0 0 0 µ µ²] [aₖ₊₆] [ µ²]
[ 0 0 0 0 0 0 0 µ ] [aₖ₊₇] [ µ ]
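In intrinsics, this extra term is exactly the one broadcast and one FMA. A minimal sketch, assuming AVX2+FMA (the names are mine: mu_pows holds [µ⁸ µ⁷ µ⁶ µ⁵ µ⁴ µ³ µ² µ] with the element for sₖ first, d holds the triangular matrix-vector product):

    #include <immintrin.h>

    /* Fold the carry-in s_{k+8} into the already-computed product d:
       one vbroadcastss + one vfmadd. */
    static inline __m256 add_carry(__m256 d, __m256 mu_pows, float s_k8)
    {
        return _mm256_fmadd_ps(_mm256_set1_ps(s_k8), mu_pows, d);
    }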
And the first matrix can be decomposed into three matrices, each of which can be applied using one shuffle and one FMA:
[ 1 0 0 0 µ⁴ 0 0 0 ] [ 1 0 µ² 0 0 0 0 0 ] [ µ µ² 0 0 0 0 0 0 ] [aₖ₊₀] [ µ⁸]
[ 0 1 0 0 µ³ 0 0 0 ] [ 0 1 µ 0 0 0 0 0 ] [ 0 µ 0 0 0 0 0 0 ] [aₖ₊₁] [ µ⁷]
[ 0 0 1 0 µ² 0 0 0 ] [ 0 0 1 0 0 0 0 0 ] [ 0 0 µ µ² 0 0 0 0 ] [aₖ₊₂] [ µ⁶]
[ 0 0 0 1 µ 0 0 0 ] [ 0 0 0 1 0 0 0 0 ] [ 0 0 0 µ 0 0 0 0 ] [aₖ₊₃] [ µ⁵]
[ 0 0 0 0 1 0 0 0 ] * [ 0 0 0 0 1 0 µ² 0 ] * [ 0 0 0 0 µ µ² 0 0 ] * [aₖ₊₄] + [ µ⁴] * sₖ₊₈
[ 0 0 0 0 0 1 0 0 ] [ 0 0 0 0 0 1 µ 0 ] [ 0 0 0 0 0 µ 0 0 ] [aₖ₊₅] [ µ³]
[ 0 0 0 0 0 0 1 0 ] [ 0 0 0 0 0 0 1 0 ] [ 0 0 0 0 0 0 µ µ²] [aₖ₊₆] [ µ²]
[ 0 0 0 0 0 0 0 1 ] [ 0 0 0 0 0 0 0 1 ] [ 0 0 0 0 0 0 0 µ ] [aₖ₊₇] [ µ ]
The right-most matrix-vector product actually costs one multiplication more, since its diagonal is µ rather than 1 (hence the vmulps below).
Overall, for 8 elements you need 4 FMAs, one multiplication, and 4 shuffles/broadcasts ($$$$ means anything (finite) can be there -- alternatively, if those entries are guaranteed to be 0, the µ vectors could be partially shared. All vectors are notated least-significant-first, and all multiplications are element-wise):
bₖ..ₖ₊₇  = [aₖ₊₀, aₖ₊₁, aₖ₊₂, aₖ₊₃, aₖ₊₄, aₖ₊₅, aₖ₊₆, aₖ₊₇] * [µ  µ  µ  µ  µ  µ  µ  µ ]   vmulps
bₖ..ₖ₊₇ += [aₖ₊₁, $$$$, aₖ₊₃, $$$$, aₖ₊₅, $$$$, aₖ₊₇, $$$$] * [µ² 0  µ² 0  µ² 0  µ² 0 ]   vshufps (or vpsrlq) + vfmadd
cₖ..ₖ₊₇  = bₖ..ₖ₊₇
cₖ..ₖ₊₇ += [bₖ₊₂, bₖ₊₂, $$$$, $$$$, bₖ₊₆, bₖ₊₆, $$$$, $$$$] * [µ² µ  0  0  µ² µ  0  0 ]   vshufps + vfmadd
dₖ..ₖ₊₇  = cₖ..ₖ₊₇
dₖ..ₖ₊₇ += [cₖ₊₄, cₖ₊₄, cₖ₊₄, cₖ₊₄, $$$$, $$$$, $$$$, $$$$] * [µ⁴ µ³ µ² µ  0  0  0  0 ]   vpermps + vfmadd
sₖ..ₖ₊₇  = dₖ..ₖ₊₇
         + [sₖ₊₈, sₖ₊₈, sₖ₊₈, sₖ₊₈, sₖ₊₈, sₖ₊₈, sₖ₊₈, sₖ₊₈] * [µ⁸ µ⁷ µ⁶ µ⁵ µ⁴ µ³ µ² µ ]   vbroadcastss + vfmadd
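Spelled out in intrinsics, the whole sequence could look as follows -- a sketch under my own naming, assuming AVX2+FMA; the constant vectors would of course be hoisted out of any loop:

    #include <immintrin.h>

    /* Compute s_k..s_{k+7} from a_k..a_{k+7} and the carry-in s_{k+8}
       for s_k = µ*(a_k + s_{k+1}). Lane i holds index k+i. */
    static __m256 block8(__m256 a, float s_k8, float mu)
    {
        const float mu2 = mu*mu, mu3 = mu2*mu, mu4 = mu2*mu2;

        /* b = a * µ                                           vmulps              */
        __m256 b = _mm256_mul_ps(a, _mm256_set1_ps(mu));

        /* b += [a1,a1,a3,a3,a5,a5,a7,a7] * [µ² 0 µ² 0 ...]    vshufps + vfmadd    */
        __m256 t = _mm256_shuffle_ps(a, a, _MM_SHUFFLE(3,3,1,1));
        b = _mm256_fmadd_ps(t, _mm256_setr_ps(mu2,0,mu2,0, mu2,0,mu2,0), b);

        /* c = b + [b2,b2,b2,b2,b6,b6,b6,b6] * [µ² µ 0 0 ...]  vshufps + vfmadd    */
        t = _mm256_shuffle_ps(b, b, _MM_SHUFFLE(2,2,2,2));
        __m256 c = _mm256_fmadd_ps(t, _mm256_setr_ps(mu2,mu,0,0, mu2,mu,0,0), b);

        /* d = c + broadcast(c4) * [µ⁴ µ³ µ² µ 0 0 0 0]        vpermps + vfmadd    */
        t = _mm256_permutevar8x32_ps(c, _mm256_set1_epi32(4));
        __m256 d = _mm256_fmadd_ps(t, _mm256_setr_ps(mu4,mu3,mu2,mu, 0,0,0,0), c);

        /* s = d + s_{k+8} * [µ⁸ µ⁷ µ⁶ µ⁵ µ⁴ µ³ µ² µ]          vbroadcastss + vfmadd */
        __m256 p = _mm256_setr_ps(mu4*mu4, mu4*mu3, mu4*mu2, mu4*mu, mu4, mu3, mu2, mu);
        return _mm256_fmadd_ps(_mm256_set1_ps(s_k8), p, d);
    }

(The zeros in the µ vectors discard the duplicated lanes the shuffles produce, i.e. the $$$$ positions above.)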
If I analyzed it correctly, the calculation of multiple dₖ blocks can be interleaved, which would hide the latencies. And the only hot path would be the final vbroadcastss + vfmadd to calculate sₖ..ₖ₊₇ from dₖ..ₖ₊₇ and sₖ₊₈. It could therefore be worthwhile to calculate blocks of 16 sₖ and fmadd µⁱ * sₖ₊₁₆ to that.
Also, the FMAs to calculate dₖ only use half of the elements. With some sophisticated swizzling one could calculate two blocks with the same number of FMAs (I assume this is not worth the effort -- but feel free to try that out).
For comparison: A pure scalar implementation requires 8 additions and 8 multiplications for 8 elements, and every operation depends on the previous result.
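A minimal sketch of that scalar baseline, under the recurrence sₖ = µ*(aₖ + sₖ₊₁) with a caller-supplied sₙ (names are mine):

    /* 1 add + 1 mul per element; every iteration depends on the previous one. */
    void block_scalar(float *s, const float *a, int n, float mu, float s_n)
    {
        float carry = s_n;                 /* s[n], supplied by the caller */
        for (int k = n - 1; k >= 0; --k)
            s[k] = carry = mu * (a[k] + carry);
    }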
N.B. You could save one multiplication if, instead of your formula, you calculated
sₖ = aₖ₊₁ + µ*sₖ₊₁
Also, in a scalar version you would then have fused multiply-adds, instead of first adding and multiplying afterwards. The result would only differ by a factor of µ.
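A sketch of that variant (my naming; with this indexing, a needs n+1 valid elements, and fmaf may be a software fallback on targets without hardware FMA):

    #include <math.h>

    /* One fused multiply-add per element: s[k] = a[k+1] + µ*s[k+1].
       Per the note above, this differs from the original s[k] only by a factor of µ. */
    void block_scalar_fma(float *s, const float *a, int n, float mu, float s_n)
    {
        float carry = s_n;
        for (int k = n - 1; k >= 0; --k)
            s[k] = carry = fmaf(mu, carry, a[k + 1]);
    }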