I found the following C++ code for fast transposition of an 8x8 matrix of 32-bit values: https://stackoverflow.com/a/51887176/1915854
#include <immintrin.h>

inline void Transpose8x8Shuff(unsigned long *in)
{
    __m256 *inI = reinterpret_cast<__m256 *>(in);
    __m256 rI[8];
    // Stage 1: interleave pairs of rows within each 128-bit lane
    rI[0] = _mm256_unpacklo_ps(inI[0], inI[1]);
    rI[1] = _mm256_unpackhi_ps(inI[0], inI[1]);
    rI[2] = _mm256_unpacklo_ps(inI[2], inI[3]);
    rI[3] = _mm256_unpackhi_ps(inI[2], inI[3]);
    rI[4] = _mm256_unpacklo_ps(inI[4], inI[5]);
    rI[5] = _mm256_unpackhi_ps(inI[4], inI[5]);
    rI[6] = _mm256_unpacklo_ps(inI[6], inI[7]);
    rI[7] = _mm256_unpackhi_ps(inI[6], inI[7]);
    __m256 rrF[8];
    __m256 *rF = reinterpret_cast<__m256 *>(rI);
    // Stage 2: gather 64-bit pairs from the interleaved rows within each lane
    rrF[0] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(1,0,1,0));
    rrF[1] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(3,2,3,2));
    rrF[2] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(1,0,1,0));
    rrF[3] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(3,2,3,2));
    rrF[4] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(1,0,1,0));
    rrF[5] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(3,2,3,2));
    rrF[6] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(1,0,1,0));
    rrF[7] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(3,2,3,2));
    // Stage 3: exchange 128-bit halves between vector pairs to finish the transpose
    rF = reinterpret_cast<__m256 *>(in);
    rF[0] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x20);
    rF[1] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x20);
    rF[2] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x20);
    rF[3] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x20);
    rF[4] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x31);
    rF[5] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x31);
    rF[6] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x31);
    rF[7] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x31);
}
However, converting it to the Java Vector API ( https://download.java.net/java/early_access/panama/docs/api/jdk.incubator.vector/jdk/incubator/vector/IntVector.html ) is not straightforward, because the Java Vector API doesn't map directly onto CPU instructions / C++ intrinsics.
What are the Java equivalents of the following intrinsics/macros?
_mm256_unpacklo_ps()
_mm256_unpackhi_ps()
_mm256_shuffle_ps()
_MM_SHUFFLE()
_mm256_permute2f128_ps()
I can use the latest JDK 19.
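For what it's worth, my understanding is that _MM_SHUFFLE is only a helper macro that packs four 2-bit lane selectors into an 8-bit immediate, so I suspect it has no direct API counterpart and the lane indices simply have to be spelled out in a VectorShuffle. A minimal sketch of the expansion as I understand it (this is my assumption, not something from the Vector API):

// _MM_SHUFFLE(z, y, x, w) expands to (z << 6) | (y << 4) | (x << 2) | w,
// i.e. four 2-bit source-lane selectors packed into one byte;
// e.g. _MM_SHUFFLE(1, 0, 1, 0) == 0x44 and _MM_SHUFFLE(3, 2, 3, 2) == 0xEE.
static int mmShuffle(int z, int y, int x, int w) {
    return (z << 6) | (y << 4) | (x << 2) | w;
}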
UPDATE: following the suggestion by @Soonts, I've implemented the code below; it passes the tests, but it's terribly slow:
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.VectorShuffle;
import jdk.incubator.vector.VectorSpecies;

public class SimdOps {
    public static final VectorSpecies<Integer> SPECIES_INT = IntVector.SPECIES_256;
    public static final VectorSpecies<Long> SPECIES_LONG = LongVector.SPECIES_256;
    // Negative ("exceptional") indices pick lanes from the second vector of the
    // two-vector rearrange: -8 is its lane 0, ..., -1 is its lane 7.
    public static final VectorShuffle<Integer> vsUnpackLo =
            VectorShuffle.fromValues(SPECIES_INT, 0, -8, 1, -7, 4, -4, 5, -3);
    public static final VectorShuffle<Integer> vsUnpackHi =
            VectorShuffle.fromValues(SPECIES_INT, 2, -6, 3, -5, 6, -2, 7, -1);
    public static final VectorShuffle<Integer> vsShuffle1010 =
            VectorShuffle.fromValues(SPECIES_INT, 0, 1, -8, -7, 4, 5, -4, -3);
    public static final VectorShuffle<Integer> vsShuffle3232 =
            VectorShuffle.fromValues(SPECIES_INT, 2, 3, -6, -5, 6, 7, -2, -1);
    public static final VectorShuffle<Integer> vsPermute0x20 =
            VectorShuffle.fromValues(SPECIES_INT, 0, 1, 2, 3, -8, -7, -6, -5);
    public static final VectorShuffle<Integer> vsPermute0x31 =
            VectorShuffle.fromValues(SPECIES_INT, 4, 5, 6, 7, -4, -3, -2, -1);

    // Transpose an 8x8 matrix of 32-bit integers, stored in eight 256-bit SIMD vectors
    public static void transpose8x8(IntVector[] inpM) {
        assert inpM.length == Constants.INTS_PER_SIMD; // 8 rows, one per vector
        // https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2
        // https://stackoverflow.com/questions/73977998/simd-transposition-of-8x8-matrix-of-32-bit-values-in-java
        final IntVector rI0 = inpM[0].rearrange(vsUnpackLo, inpM[1]);
        final IntVector rI1 = inpM[0].rearrange(vsUnpackHi, inpM[1]);
        final IntVector rI2 = inpM[2].rearrange(vsUnpackLo, inpM[3]);
        final IntVector rI3 = inpM[2].rearrange(vsUnpackHi, inpM[3]);
        final IntVector rI4 = inpM[4].rearrange(vsUnpackLo, inpM[5]);
        final IntVector rI5 = inpM[4].rearrange(vsUnpackHi, inpM[5]);
        final IntVector rI6 = inpM[6].rearrange(vsUnpackLo, inpM[7]);
        final IntVector rI7 = inpM[6].rearrange(vsUnpackHi, inpM[7]);
        final IntVector rrF0 = rI0.rearrange(vsShuffle1010, rI2);
        final IntVector rrF1 = rI0.rearrange(vsShuffle3232, rI2);
        final IntVector rrF2 = rI1.rearrange(vsShuffle1010, rI3);
        final IntVector rrF3 = rI1.rearrange(vsShuffle3232, rI3);
        final IntVector rrF4 = rI4.rearrange(vsShuffle1010, rI6);
        final IntVector rrF5 = rI4.rearrange(vsShuffle3232, rI6);
        final IntVector rrF6 = rI5.rearrange(vsShuffle1010, rI7);
        final IntVector rrF7 = rI5.rearrange(vsShuffle3232, rI7);
        inpM[0] = rrF0.rearrange(vsPermute0x20, rrF4);
        inpM[1] = rrF1.rearrange(vsPermute0x20, rrF5);
        inpM[2] = rrF2.rearrange(vsPermute0x20, rrF6);
        inpM[3] = rrF3.rearrange(vsPermute0x20, rrF7);
        inpM[4] = rrF0.rearrange(vsPermute0x31, rrF4);
        inpM[5] = rrF1.rearrange(vsPermute0x31, rrF5);
        inpM[6] = rrF2.rearrange(vsPermute0x31, rrF6);
        inpM[7] = rrF3.rearrange(vsPermute0x31, rrF7);
    }
}
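For completeness, this is roughly how the tests drive it (a simplified sketch; the row-major int[] layout and the offsets are just how I happen to store the matrix, and the incubator module is enabled with --add-modules jdk.incubator.vector):

int[] matrix = new int[64]; // 8x8, row-major
IntVector[] rows = new IntVector[8];
for (int r = 0; r < 8; r++) {
    rows[r] = IntVector.fromArray(SimdOps.SPECIES_INT, matrix, r * 8); // load row r
}
SimdOps.transpose8x8(rows);
for (int r = 0; r < 8; r++) {
    rows[r].intoArray(matrix, r * 8); // store transposed row r back
}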
The bottleneck is jdk.incubator.vector.Int256Vector.rearrange(VectorShuffle, Vector). It's at least 10 times slower than the scalar code. Any ideas?
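(For comparison, the scalar code I'm benchmarking against is essentially the textbook in-place swap across the diagonal, along these lines; a sketch, not my exact code:)

// Plain scalar 8x8 transpose used as the baseline: swap elements across the main diagonal.
static void transposeScalar8x8(int[] m) { // m.length == 64, row-major
    for (int r = 0; r < 8; r++) {
        for (int c = r + 1; c < 8; c++) {
            int tmp = m[r * 8 + c];
            m[r * 8 + c] = m[c * 8 + r];
            m[c * 8 + r] = tmp;
        }
    }
}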