1

I found the following C++ code for fast transposition of an 8x8 matrix of 32-bit values: https://stackoverflow.com/a/51887176/1915854

  inline void Transpose8x8Shuff(unsigned long *in)
  {
     __m256 *inI = reinterpret_cast<__m256 *>(in);
     __m256 rI[8];
     rI[0] = _mm256_unpacklo_ps(inI[0], inI[1]); 
     rI[1] = _mm256_unpackhi_ps(inI[0], inI[1]); 
     rI[2] = _mm256_unpacklo_ps(inI[2], inI[3]); 
     rI[3] = _mm256_unpackhi_ps(inI[2], inI[3]); 
     rI[4] = _mm256_unpacklo_ps(inI[4], inI[5]); 
     rI[5] = _mm256_unpackhi_ps(inI[4], inI[5]); 
     rI[6] = _mm256_unpacklo_ps(inI[6], inI[7]); 
     rI[7] = _mm256_unpackhi_ps(inI[6], inI[7]); 

     __m256 rrF[8];
     __m256 *rF = reinterpret_cast<__m256 *>(rI);
     rrF[0] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(1,0,1,0));
     rrF[1] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(3,2,3,2));
     rrF[2] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(1,0,1,0)); 
     rrF[3] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(3,2,3,2));
     rrF[4] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(1,0,1,0));
     rrF[5] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(3,2,3,2));
     rrF[6] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(1,0,1,0));
     rrF[7] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(3,2,3,2));

     rF = reinterpret_cast<__m256 *>(in);
     rF[0] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x20);
     rF[1] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x20);
     rF[2] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x20);
     rF[3] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x20);
     rF[4] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x31);
     rF[5] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x31);
     rF[6] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x31);
     rF[7] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x31);
  }

However, converting it to the Java Vector API ( https://download.java.net/java/early_access/panama/docs/api/jdk.incubator.vector/jdk/incubator/vector/IntVector.html ) is not straightforward, because the Java Vector API doesn't map directly to CPU instructions or C++ intrinsics.

What are the Java equivalents of the following intrinsics/macros?

  1. _mm256_unpacklo_ps()
  2. _mm256_unpackhi_ps()
  3. _mm256_shuffle_ps()
  4. _MM_SHUFFLE()
  5. _mm256_permute2f128_ps()

I can use the latest JDK 19.

UPDATE: following the suggestion by @Soonts, I've implemented the following. It passes tests, but it's terribly slow:

/**
 * SIMD helpers built on the incubating Java Vector API (jdk.incubator.vector).
 */
public class SimdOps {
    public static final VectorSpecies<Integer> SPECIES_INT = IntVector.SPECIES_256;
    public static final VectorSpecies<Long> SPECIES_LONG = LongVector.SPECIES_256;

    // Two-operand shuffles for x.rearrange(shuffle, y). Non-negative indices
    // pick lanes of x. Indices -8..-1 pick lanes 0..7 of y.
    // Each one reproduces the lane pattern of the named AVX intrinsic applied to (x, y).

    // _mm256_unpacklo_ps(x, y)
    public static final VectorShuffle<Integer> vsUnpackLo =
            VectorShuffle.fromValues(SPECIES_INT, 0, -8, 1, -7, 4, -4, 5, -3);
    // _mm256_unpackhi_ps(x, y)
    public static final VectorShuffle<Integer> vsUnpackHi =
            VectorShuffle.fromValues(SPECIES_INT, 2, -6, 3, -5, 6, -2, 7, -1);
    // _mm256_shuffle_ps(x, y, _MM_SHUFFLE(1,0,1,0))
    public static final VectorShuffle<Integer> vsShuffle1010 =
            VectorShuffle.fromValues(SPECIES_INT, 0, 1, -8, -7, 4, 5, -4, -3);
    // _mm256_shuffle_ps(x, y, _MM_SHUFFLE(3,2,3,2))
    public static final VectorShuffle<Integer> vsShuffle3232 =
            VectorShuffle.fromValues(SPECIES_INT, 2, 3, -6, -5, 6, 7, -2, -1);
    // _mm256_permute2f128_ps(x, y, 0x20): low 128 bits of x, then low 128 bits of y
    public static final VectorShuffle<Integer> vsPermute0x20 =
            VectorShuffle.fromValues(SPECIES_INT, 0, 1, 2, 3, -8, -7, -6, -5);
    // _mm256_permute2f128_ps(x, y, 0x31): high 128 bits of x, then high 128 bits of y
    public static final VectorShuffle<Integer> vsPermute0x31 =
            VectorShuffle.fromValues(SPECIES_INT, 4, 5, 6, 7, -4, -3, -2, -1);

    /**
     * Transposes, in place, an 8x8 matrix of 32-bit integers held as eight
     * 256-bit row vectors. This is a port of the AVX
     * unpack -> shuffle -> permute2f128 sequence from
     * https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2
     * (see also https://stackoverflow.com/questions/73977998/simd-transposition-of-8x8-matrix-of-32-bit-values-in-java).
     *
     * @param inpM the row vectors. On return, {@code inpM[i]} holds former column {@code i}.
     */
    public static final void transpose8x8(IntVector[] inpM) {
        assert inpM.length == Constants.INTS_PER_SIMD;

        final IntVector row0 = inpM[0];
        final IntVector row1 = inpM[1];
        final IntVector row2 = inpM[2];
        final IntVector row3 = inpM[3];
        final IntVector row4 = inpM[4];
        final IntVector row5 = inpM[5];
        final IntVector row6 = inpM[6];
        final IntVector row7 = inpM[7];

        // Stage 1: interleave row pairs within each 128-bit half.
        final IntVector lo01 = row0.rearrange(vsUnpackLo, row1);
        final IntVector hi01 = row0.rearrange(vsUnpackHi, row1);
        final IntVector lo23 = row2.rearrange(vsUnpackLo, row3);
        final IntVector hi23 = row2.rearrange(vsUnpackHi, row3);
        final IntVector lo45 = row4.rearrange(vsUnpackLo, row5);
        final IntVector hi45 = row4.rearrange(vsUnpackHi, row5);
        final IntVector lo67 = row6.rearrange(vsUnpackLo, row7);
        final IntVector hi67 = row6.rearrange(vsUnpackHi, row7);

        // Stage 2: 4-element column fragments. The top* values come from rows 0-3
        // and the bottom* values from rows 4-7.
        final IntVector top0 = lo01.rearrange(vsShuffle1010, lo23);
        final IntVector top1 = lo01.rearrange(vsShuffle3232, lo23);
        final IntVector top2 = hi01.rearrange(vsShuffle1010, hi23);
        final IntVector top3 = hi01.rearrange(vsShuffle3232, hi23);
        final IntVector bottom0 = lo45.rearrange(vsShuffle1010, lo67);
        final IntVector bottom1 = lo45.rearrange(vsShuffle3232, lo67);
        final IntVector bottom2 = hi45.rearrange(vsShuffle1010, hi67);
        final IntVector bottom3 = hi45.rearrange(vsShuffle3232, hi67);

        // Stage 3: join 128-bit halves. Low halves form columns 0-3 and high halves form columns 4-7.
        inpM[0] = top0.rearrange(vsPermute0x20, bottom0);
        inpM[4] = top0.rearrange(vsPermute0x31, bottom0);
        inpM[1] = top1.rearrange(vsPermute0x20, bottom1);
        inpM[5] = top1.rearrange(vsPermute0x31, bottom1);
        inpM[2] = top2.rearrange(vsPermute0x20, bottom2);
        inpM[6] = top2.rearrange(vsPermute0x31, bottom2);
        inpM[3] = top3.rearrange(vsPermute0x20, bottom3);
        inpM[7] = top3.rearrange(vsPermute0x31, bottom3);
    }
}

And the bottleneck is jdk.incubator.vector.Int256Vector.rearrange(VectorShuffle, Vector) . It's at least 10 times slower than the scalar code. Any ideas?

Serge Rogatch
  • 13,865
  • 7
  • 86
  • 158
  • Java aims to be platform-independent. These seem to be (Intel-)platform dependent operations. How should there be a direct replacement in Java? One option would be call the C++ method via JNI. see https://docs.oracle.com/en/java/javase/11/docs/specs/jni/index.html or https://www.baeldung.com/jni – nineninesevenfour Oct 06 '22 at 18:00
  • 1
    @nineninesevenfour: Maybe you missed the [project-panama] tag? Allowing stuff like this more efficiently and/or easily than traditional JNI seems to be the whole point of https://openjdk.org/projects/panama/. And the fact that Java `jdk.incubator.vector` (https://download.java.net/java/early_access/panama/docs/api/jdk.incubator.vector/jdk/incubator/vector/package-summary.html) is supposed to give access to something like intrinsics directly, for the JVM's JIT to emit such instructions. – Peter Cordes Oct 06 '22 at 18:08
  • @nineninesevenfour, what is the overhead of switching the control flow from Java to C++? I'm afraid it will be too high given that this method is called tens of millions of times per second and there is no way to rewrite the higher-level methods in C++. – Serge Rogatch Oct 06 '22 at 18:08
  • 3
    The docs (https://download.java.net/java/early_access/panama/docs/api/jdk.incubator.vector/jdk/incubator/vector/Vector.html#cross-lane) say that shuffles should be done with `rearrange()`. Presumably a shuffle optimizer will figure out how to actually implement the shuffles you request in terms of machine instructions, so choose ones that can be done in a single instruction on the platform you care about. Similar idea to A. Fog's VCL where the template `permute<2,3,0,1>(v)` picks shuffles https://github.com/vectorclass/version2/blob/08959ebe6ea5d8317330b242e28ba0d2938ac52f/vectori256.h#L3862 – Peter Cordes Oct 06 '22 at 18:19
  • @SergeRogatch: Yes, I have overlooked the tag. My apologies. – nineninesevenfour Oct 06 '22 at 20:27

1 Answer

2

Disclaimer: I never wrote anything similar in Java.

Based on the documentation, the rearrange seems the only way to go.
The only issue is how to translate the C intrinsics into the shuffle indices for VectorShuffle<Float>.

Here's C++ code to find out:

// Prints `name` followed by the eight lanes of `v`. Each lane is first
// converted to a 32-bit integer with _mm256_cvtps_epi32, which is exact for
// the small integral lane values this is used with.
void printShuffle( __m256 v, const char* name )
{
    std::array<int, 8> lanes;
    const __m256i asInts = _mm256_cvtps_epi32( v );
    _mm256_storeu_si256( reinterpret_cast<__m256i*>( lanes.data() ), asInts );
    printf( "%s: %i, %i, %i, %i, %i, %i, %i, %i\n", name,
        lanes[ 0 ], lanes[ 1 ], lanes[ 2 ], lanes[ 3 ],
        lanes[ 4 ], lanes[ 5 ], lanes[ 6 ], lanes[ 7 ] );
}

// Evaluates an intrinsic expression and prints its source text together with its lanes.
#define TEST( expr ) printShuffle( expr, #expr )

// Prints the Java VectorShuffle indices that mimic each AVX intrinsic used by
// the 8x8 transpose. Lanes of `a` hold 0..7 and lanes of `b` hold -8..-1, so
// every printed value names the input lane that lands in that output position.
// That value can be used directly as an index for x.rearrange(shuffle, y).
void printJavaRearranges()
{
    const __m256 a = _mm256_setr_ps( 0, 1, 2, 3, 4, 5, 6, 7 );
    const __m256 b = _mm256_setr_ps( -8, -7, -6, -5, -4, -3, -2, -1 );

    // In-lane interleaves.
    TEST( _mm256_unpacklo_ps( a, b ) );
    TEST( _mm256_unpackhi_ps( a, b ) );

    // Two-source in-lane shuffles.
    TEST( _mm256_shuffle_ps( a, b, _MM_SHUFFLE(1,0,1,0) ) );
    TEST( _mm256_shuffle_ps( a, b, _MM_SHUFFLE(3,2,3,2) ) );

    // 128-bit half selection.
    TEST( _mm256_permute2f128_ps( a, b, 0x20 ) );
    TEST( _mm256_permute2f128_ps( a, b, 0x31 ) );
}

Output:

_mm256_unpacklo_ps( a, b ): 0, -8, 1, -7, 4, -4, 5, -3
_mm256_unpackhi_ps( a, b ): 2, -6, 3, -5, 6, -2, 7, -1
_mm256_shuffle_ps( a, b, _MM_SHUFFLE(1,0,1,0) ): 0, 1, -8, -7, 4, 5, -4, -3
_mm256_shuffle_ps( a, b, _MM_SHUFFLE(3,2,3,2) ): 2, 3, -6, -5, 6, 7, -2, -1
_mm256_permute2f128_ps( a, b, 0x20 ): 0, 1, 2, 3, -8, -7, -6, -5
_mm256_permute2f128_ps( a, b, 0x31 ): 4, 5, 6, 7, -4, -3, -2, -1

The _mm256_permute2f128_ps instruction can selectively zero out lanes, Java's vector API probably can't do that. Fortunately, the immediate values in your source code don't zero out any pieces.

If you're lucky, the runtime might map these values (when they are known to JIT in advance and never change) into the corresponding AVX instructions.

Soonts
  • 20,079
  • 9
  • 57
  • 130
  • Thanks, your answer yields correct results, but the `rearrange()` based code is extremely slow. Please, see the update from my question. – Serge Rogatch Oct 10 '22 at 12:16
  • @SergeRogatch You could try calling your function a few hundred times, and only then measure the performance. If you’re lucky, the JVM may eventually recompile your Java code into the proper AVX shuffle instructions. If it is still slow, you could file a performance bug with Oracle, and in the meantime try integrating a native DLL written in C or C++ into your Java program. – Soonts Oct 10 '22 at 12:26
  • This function is called tens of millions of times throughout the program, and I measure its runtime by where it appears in the profiler. – Serge Rogatch Oct 10 '22 at 12:28
  • 1
    @SergeRogatch One more thing. Your Java code doesn’t rearrange floats, it rearranges integers. AVX2 doesn’t have most of these shuffles for integer lanes; all these instructions only support FP32 vectors. The only one of these shuffles supported for integer vectors is `_mm256_permute2x128_si256`. Look for a Java API to reinterpret int32 vectors as fp32 vectors, and use the rearrange method on the fp32 vectors. – Soonts Oct 10 '22 at 12:31
  • AVX2 has `_mm256_unpacklo/hi_epi32` (`vpunpckldq` / `vpunpckhdq`). But yes, a 2-input variable-control shuffle like `vshufps` is unique to `float` vectors, so in C you'd `_mm256_castsi256_ps` to use that FP shuffle on integer data. If a JVM doesn't realize this, it might not find an efficient implementation of that `rearrange`. – Peter Cordes Oct 10 '22 at 15:01