6

As per this answer, I've created the following test program:

#include <iso646.h>
#include <immintrin.h>

#include <stdint.h>
#include <stdio.h>

// SHIFT_LEFT( N ) defines  static inline __m256i shift_left_N ( __m256i A ),
// which shifts the whole 256-bit vector A left by N BYTES (not bits), moving
// bytes across the 128-bit lane boundary.  N must be a non-negative integer
// literal because it is pasted into the function name; N >= 32 yields zero.
// The function is `static inline` so that in C99 no external definition is
// required when the compiler decides not to inline a call.
#define SHIFT_LEFT( N ) \
\
static inline __m256i shift_left_##N ( __m256i A ) { \
\
    /* Low 128-bit lane of A moved into the high lane, low lane zeroed. */ \
    const __m256i lo_up = _mm256_permute2x128_si256 ( A, A, _MM_SHUFFLE ( 0, 0, 2, 0 ) ); \
\
    /* Compilers may reject out-of-range immediates even in branches that */ \
    /* are dead for this N, so every byte count is clamped into range.    */ \
    if ( N == 0 ) return A; \
    else if ( N <  16 ) return _mm256_alignr_epi8 ( A, lo_up, ( ( N ) < 16 ) ? 16 - ( N ) : 0 ); \
    else if ( N == 16 ) return lo_up; \
    else if ( N <  32 ) return _mm256_slli_si256 ( lo_up, ( ( N ) > 16 && ( N ) < 32 ) ? ( N ) - 16 : 0 ); \
    else return _mm256_setzero_si256 ( ); \
}

/* Print the bits of n, most significant first, as '0'/'1' characters in
   groups of two, each group followed by a space.  No newline is written. */
void print ( const size_t n ) {

    /* Top bit of size_t, whatever width size_t has on this target. */
    size_t bit = ~( ( size_t ) -1 >> 1 );

    while ( bit ) {

        /* Emit the digit for the bit, not the masked value itself. */
        putchar ( ( n & bit ) ? '1' : '0' );
        bit >>= 1;
        putchar ( ( n & bit ) ? '1' : '0' );
        bit >>= 1;

        putchar ( ' ' );
    }
}

SHIFT_LEFT ( 2 )

int main ( ) {

    /* The lowest 64-bit element holds 0x03; shift_left_2 moves it up two
       BYTES (16 bits), so the lowest element should become 0x30000. */
    __m256i a = _mm256_set_epi64x ( 0x00, 0x00, 0x00, 0x03 );
    __m256i b = shift_left_2 ( a );

    /* Copy the four 64-bit elements out with a store rather than reading b
       through a size_t pointer, which violates strict aliasing. */
    uint64_t c [ 4 ];
    _mm256_storeu_si256 ( ( __m256i * ) c, b );

    print ( c [ 3 ] ); print ( c [ 2 ] ); print ( c [ 1 ] ); print ( c [ 0 ] ); putchar ( '\n' );

    return 0;
}

The above program does not give the output I expect, as far as I can see. I'm stumped as to how these functions work together (see their descriptions). Am I doing something wrong, or is the implementation of shift_left() wrong?

EDIT1: I came to realize (and it was confirmed in the comments) that this code only shifts by at most 32 bytes (the shift count is in bytes, not bits), so it does not satisfy my goal. That leaves the question: "How do I implement a lane-crossing, logical bit-wise shift (left and right) in AVX2?"

EDIT2: Fast forward: in the meantime, I have a better understanding of how it works and have coded what I needed. I've posted the code (shift and rotate) as an answer and accepted it.

degski
  • 642
  • 5
  • 13

2 Answers

9

Probably not the kind of answer that you're expecting. But here's a reasonably efficient solution that actually works for a run-time shift amount.

The costs are:

  • Preprocess: ~12 - 14 instructions
  • Rotation: 5 instructions
  • Shift: 6 instructions

In order to shift or rotate anything, you must first preprocess the shift amount. Once you have that, you can efficiently perform shifts/rotations.

Because the preprocessing step is so expensive, this solution utilizes an object to hold the preprocessed shift amount so that it can be reused many times when shifting by the same amount.

For efficiency, the object should be on the stack in the same scope as the code that does the shifting. This allows the compiler to promote all the fields of the object into registers. Furthermore, it's recommended to force-inline all the methods of the class.

#include <stdint.h>
#include <immintrin.h>

//  Lane-crossing left shift / rotate of a whole __m256i by a run-time bit
//  count.  The count is split once, in the constructor, into a whole-dword
//  offset (bits / 32, applied with cross-lane dword permutes) and a residual
//  in-dword shift (bits % 32); rotate() and shift() only reuse those values.
class LeftShifter_AVX2{
public:
    LeftShifter_AVX2(uint32_t bits)
        : permL(_mm256_sub_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7),
                                 _mm256_set1_epi32(bits / 32)))
        , permR(_mm256_sub_epi32(permL, _mm256_set1_epi32(1)))
        , shiftL(_mm_cvtsi32_si128(bits % 32))
        , shiftR(_mm_cvtsi32_si128(32 - bits % 32))
        , mask(build_mask(permL, permR, shiftR))
    {}

    //  Rotate left: every output dword is a funnel shift of two neighbouring
    //  input dwords.  permutevar8x32 uses only the low 3 bits of each index,
    //  so negative indices wrap around to the top of the vector.
    __m256i rotate(__m256i x) const{
        const __m256i upper = _mm256_sll_epi32(_mm256_permutevar8x32_epi32(x, permL), shiftL);
        const __m256i lower = _mm256_srl_epi32(_mm256_permutevar8x32_epi32(x, permR), shiftR);
        return _mm256_or_si256(upper, lower);
    }

    //  Logical shift left: the rotation with the wrapped-around bits cleared.
    __m256i shift(__m256i x) const{
        const __m256i rotated = rotate(x);
        return _mm256_andnot_si256(mask, rotated);
    }

private:
    //  Bits that wrapped around: whole dwords whose left source index is
    //  negative, plus the low (bits % 32) bits of the dword whose right-hand
    //  source index is negative (all-ones shifted right by 32 - bits % 32).
    static __m256i build_mask(__m256i pl, __m256i pr, __m128i sr){
        const __m256i zero    = _mm256_setzero_si256();
        const __m256i whole   = _mm256_cmpgt_epi32(zero, pl);
        const __m256i partial = _mm256_srl_epi32(_mm256_cmpgt_epi32(zero, pr), sr);
        return _mm256_or_si256(whole, partial);
    }

    __m256i permL;
    __m256i permR;
    __m128i shiftL;
    __m128i shiftR;
    __m256i mask;
};

Test Program:

#include <iostream>
using namespace std;

//  Print the 32 bytes of x (lowest address first) as decimal numbers, each
//  followed by a space, then end the line.  The vector is stored to a byte
//  array instead of being read back through an inactive union member, which
//  is undefined behaviour in C++.
void print_u8(__m256i x){
    uint8_t s[32];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(s), x);
    for (int c = 0; c < 32; c++){
        std::cout << static_cast<int>(s[c]) << " ";
    }
    std::cout << std::endl;
}

int main(){
    //  Build a vector holding the bytes 0..31 (lowest address first).  It is
    //  loaded from a plain array: writing one union member and then reading
    //  another is undefined behaviour in C++.
    char buffer[32];
    for (int c = 0; c < 32; c++){
        buffer[c] = (char)c;
    }
    const __m256i x = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer));

    print_u8(x);
    print_u8(LeftShifter_AVX2(0).shift(x));
    print_u8(LeftShifter_AVX2(8).shift(x));
    print_u8(LeftShifter_AVX2(32).shift(x));
    print_u8(LeftShifter_AVX2(40).shift(x));
}

Output:

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 
0 0 0 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 
0 0 0 0 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26

Right-shift is very similar. I'll leave that as an exercise for the reader.

Mysticial
  • 464,885
  • 45
  • 335
  • 332
  • Interesting. The "run-time"-functions are as fast (no _mm256_permute4x64_epi64) as it's gonna get, I think, but at a price. I would like to write this in c, so I'll have to think how this works in c (the stack caveat). – degski Mar 21 '18 at 13:37
  • The object doesn't need to be a temporary. Just declare it as a local variable and it will work. What you want to avoid is passing it into other places either by pointer or by reference. This will break compiler optimizations since the object is no longer a local variable. If you need to pass it around, either do it by value or make a local copy of it before you use it. – Mysticial Mar 21 '18 at 16:20
  • For generalized shifting this then becomes quite messy. I realized that passing by pointer is a no go, but I'm surprised (a little bit) that a reference (in C++) breaks scope as well, but I guess that since they are likely implemented as pointer-to-pointer vars with auto-deref, one could not expect that. It's not a cheap operation, wtf, a simple left or right shift! – degski Mar 22 '18 at 13:12
  • Another option is to simply not cache the preprocessing. Just wrap it into a single function and not bother with the object at all. But then you'll be relying entirely on the compiler to cache it with CSE (Common Subexpression Elimination). Fundamentally, shifting a SIMD vector is not SIMD. So there's no hardware for it. And TBH, there aren't a lot of legitimate uses for a full-vector shift anyway. – Mysticial Mar 22 '18 at 16:02
  • I'm coding a tritset (like a std::bitset, but with false, true and unknown) in avx2, for that I need at least shift by one to implement the logic. – degski Mar 22 '18 at 16:20
  • @Mysticial, your Stack Overflow profile page has an invalid link "22.4 trillion digits (November 2016)". Sorry its off-topic here. – Boris Oct 02 '18 at 17:50
  • @Boris Thanks. Hate external links for this reason. The information is easily googled, so I'll just unlink it. – Mysticial Oct 02 '18 at 19:37
3

The following code implements lane-crossing logical bit-wise shift/rotate (left and right) in AVX2:

// Prototypes...
//
// Lane-crossing operations that treat a __m256i as one 256-bit integer
// (64-bit word 0 least significant); the int argument is the amount in bits.
//   _mm256_sli_si256 : logical shift left
//   _mm256_sri_si256 : logical shift right
//   _mm256_rli_si256 : rotate left
//   _mm256_rri_si256 : rotate right
// NOTE(review): identifiers starting with an underscore are reserved at
// global scope, and the _mm256_ prefix can clash with compiler intrinsics;
// consider renaming.

__m256i _mm256_sli_si256 ( __m256i, int );
__m256i _mm256_sri_si256 ( __m256i, int );
__m256i _mm256_rli_si256 ( __m256i, int );
__m256i _mm256_rri_si256 ( __m256i, int );


// Implementations...

__m256i left_shift_000_063 ( __m256i a, int n ) { // 6

    // Shift the 256-bit value left by n bits, 0 <= n < 64.  Every 64-bit word
    // is shifted in place; the bits pushed out of the top of word i are
    // carried into the bottom of word i + 1, and word 0 receives no carry.
    const __m256i shifted = _mm256_slli_epi64 ( a, n );
    const __m256i spill   = _mm256_srli_epi64 ( a, 64 - n );                                  // top n bits of each word (0 when n == 0)
    const __m256i moved   = _mm256_permute4x64_epi64 ( spill, _MM_SHUFFLE ( 2, 1, 0, 0 ) );   // word i <- spill word i - 1
    const __m256i carry   = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), moved, _MM_SHUFFLE ( 3, 3, 3, 0 ) ); // clear word 0

    return _mm256_or_si256 ( shifted, carry );
}

__m256i left_shift_064_127 ( __m256i a, int n ) { // 7

    // Shift left by 64 + n bits, 0 <= n < 64:
    //   word i = ( a[i-1] << n ) | ( a[i-2] >> ( 64 - n ) ), missing words read as 0.
    const __m256i zero = _mm256_setzero_si256 ( );

    const __m256i main_part = _mm256_blend_epi32 ( zero,
        _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, n ), _MM_SHUFFLE ( 2, 1, 0, 0 ) ),
        _MM_SHUFFLE ( 3, 3, 3, 0 ) );   // words 1..3 <- shifted words 0..2
    const __m256i carry_part = _mm256_blend_epi32 ( zero,
        _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 1, 0, 0, 0 ) ),
        _MM_SHUFFLE ( 3, 3, 0, 0 ) );   // words 2..3 <- spilled bits of words 0..1

    return _mm256_or_si256 ( main_part, carry_part );
}

__m256i left_shift_128_191 ( __m256i a, int n ) { // 7

    // Shift left by 128 + n bits, 0 <= n < 64:
    //   word 2 = a[0] << n
    //   word 3 = ( a[1] << n ) | ( a[0] >> ( 64 - n ) )
    //   words 0, 1 = 0
    __m256i b = _mm256_slli_epi64 ( a, n );
    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 1, 0, 0, 0 ) );  // d[3] = b[1], d[2] = b[0]

    // The carry into word 3 comes from word 0.  The previous selector
    // _MM_SHUFFLE ( 1, 0, 0, 0 ) put c[1] there, corrupting word 3 for n > 0.
    __m256i c = _mm256_srli_epi64 ( a, 64 - n );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 0, 0, 0, 0 ) );  // e[3] = c[0]

    __m256i f = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), d, _MM_SHUFFLE ( 3, 3, 0, 0 ) ); // keep words 2, 3
    __m256i g = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), e, _MM_SHUFFLE ( 3, 0, 0, 0 ) ); // keep word 3

    return _mm256_or_si256 ( f, g );
}

__m256i left_shift_192_255 ( __m256i a, int n ) { // 5

    // Shift left by 192 + n bits, 0 <= n < 64: only word 3 survives and it
    // holds a[0] << n.
    const __m256i bcast = _mm256_permute4x64_epi64 ( a, _MM_SHUFFLE ( 0, 0, 0, 0 ) );   // every word = a[0]
    const __m256i moved = _mm256_slli_epi64 ( bcast, n );

    return _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), moved, _MM_SHUFFLE ( 3, 0, 0, 0 ) ); // keep word 3
}

__m256i _mm256_sli_si256 ( __m256i a, int n ) {

    // Logical left shift of the whole 256-bit vector by n bits.
    // Each range helper covers one 64-bit word of displacement and takes the
    // residual bit count 0..63.  Amounts outside 0..255 move every bit out
    // of the register, so the result is all zeros; without this guard,
    // n >= 256 would be handled as a shift by 192 + n % 64.
    if ( n < 0 || n > 255 ) return _mm256_setzero_si256 ( );

    if ( n < 128 ) return n <  64 ? left_shift_000_063 ( a, n ) : left_shift_064_127 ( a, n % 64 );
    else           return n < 192 ? left_shift_128_191 ( a, n % 64 ) : left_shift_192_255 ( a, n % 64 );
}


__m256i right_shift_000_063 ( __m256i a, int n ) { // 6

    // Shift the 256-bit value right by n bits, 0 <= n < 64.  Bits pushed out
    // of the bottom of word i + 1 are carried into the top of word i; word 3
    // receives no carry.
    const __m256i shifted = _mm256_srli_epi64 ( a, n );
    const __m256i spill   = _mm256_slli_epi64 ( a, 64 - n );                                  // low n bits of each word, moved to the top
    const __m256i moved   = _mm256_permute4x64_epi64 ( spill, _MM_SHUFFLE ( 0, 3, 2, 1 ) );   // word i <- spill word i + 1
    const __m256i carry   = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), moved, _MM_SHUFFLE ( 0, 3, 3, 3 ) ); // clear word 3

    return _mm256_or_si256 ( shifted, carry );
}

__m256i right_shift_064_127 ( __m256i a, int n ) { // 7

    // Shift right by 64 + n bits, 0 <= n < 64:
    //   word i = ( a[i+1] >> n ) | ( a[i+2] << ( 64 - n ) ), missing words read as 0.
    const __m256i zero = _mm256_setzero_si256 ( );

    const __m256i main_part = _mm256_blend_epi32 ( zero,
        _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, n ), _MM_SHUFFLE ( 3, 3, 2, 1 ) ),
        _MM_SHUFFLE ( 0, 3, 3, 3 ) );   // words 0..2 <- shifted words 1..3
    const __m256i carry_part = _mm256_blend_epi32 ( zero,
        _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 3, 3, 3, 2 ) ),
        _MM_SHUFFLE ( 0, 0, 3, 3 ) );   // words 0..1 <- spilled bits of words 2..3

    return _mm256_or_si256 ( main_part, carry_part );
}

__m256i right_shift_128_191 ( __m256i a, int n ) { // 7

    // Shift right by 128 + n bits, 0 <= n < 64:
    //   word 0 = ( a[2] >> n ) | ( a[3] << ( 64 - n ) )
    //   word 1 = a[3] >> n
    //   words 2, 3 = 0
    const __m256i zero = _mm256_setzero_si256 ( );

    const __m256i main_part = _mm256_blend_epi32 ( zero,
        _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, n ), _MM_SHUFFLE ( 3, 2, 3, 2 ) ),
        _MM_SHUFFLE ( 0, 0, 3, 3 ) );   // words 0..1 <- shifted words 2..3
    const __m256i carry_part = _mm256_blend_epi32 ( zero,
        _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 3, 2, 1, 3 ) ),
        _MM_SHUFFLE ( 0, 0, 0, 3 ) );   // word 0 <- spilled bits of word 3

    return _mm256_or_si256 ( main_part, carry_part );
}

__m256i right_shift_192_255 ( __m256i a, int n ) { // 5

    // Shift right by 192 + n bits, 0 <= n < 64: only word 0 survives and it
    // holds a[3] >> n.
    const __m256i top   = _mm256_permute4x64_epi64 ( a, _MM_SHUFFLE ( 0, 0, 0, 3 ) );   // word 0 = a[3]
    const __m256i moved = _mm256_srli_epi64 ( top, n );

    return _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), moved, _MM_SHUFFLE ( 0, 0, 0, 3 ) ); // keep word 0
}

__m256i _mm256_sri_si256 ( __m256i a, int n ) {

    // Logical right shift of the whole 256-bit vector by n bits.
    // Each range helper covers one 64-bit word of displacement and takes the
    // residual bit count 0..63.  Amounts outside 0..255 move every bit out
    // of the register, so the result is all zeros; without this guard,
    // n >= 256 would be handled as a shift by 192 + n % 64.
    if ( n < 0 || n > 255 ) return _mm256_setzero_si256 ( );

    if ( n < 128 ) return n <  64 ? right_shift_000_063 ( a, n ) : right_shift_064_127 ( a, n % 64 );
    else           return n < 192 ? right_shift_128_191 ( a, n % 64 ) : right_shift_192_255 ( a, n % 64 );
}


__m256i left_rotate_000_063 ( __m256i a, int n ) { // 5

    // Rotate the 256-bit value left by n bits, 0 <= n < 64: every word is
    // shifted left and refilled with the top n bits of the word below it;
    // word 0 takes them from word 3.
    const __m256i body = _mm256_slli_epi64 ( a, n );
    const __m256i wrap = _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 2, 1, 0, 3 ) ); // word i <- word i - 1 (mod 4)

    return _mm256_or_si256 ( body, wrap );
}

__m256i left_rotate_064_127 ( __m256i a, int n ) { // 6

    // Rotate left by 64 + n bits, 0 <= n < 64 (word indices mod 4):
    //   word i = ( a[i-1] << n ) | ( a[i-2] >> ( 64 - n ) )
    const __m256i hi_bits = _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, n ),      _MM_SHUFFLE ( 2, 1, 0, 3 ) ); // word i <- word i - 1
    const __m256i lo_bits = _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 1, 0, 3, 2 ) ); // word i <- word i - 2

    return _mm256_or_si256 ( hi_bits, lo_bits );
}

__m256i left_rotate_128_191 ( __m256i a, int n ) { // 6

    // Rotate left by 128 + n bits, 0 <= n < 64 (word indices mod 4):
    //   word i = ( a[i-2] << n ) | ( a[i-3] >> ( 64 - n ) )
    const __m256i hi_bits = _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, n ),      _MM_SHUFFLE ( 1, 0, 3, 2 ) ); // word i <- word i - 2
    const __m256i lo_bits = _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 0, 3, 2, 1 ) ); // word i <- word i - 3

    return _mm256_or_si256 ( hi_bits, lo_bits );
}

__m256i left_rotate_192_255 ( __m256i a, int n ) { // 5

    // Rotate left by 192 + n bits, 0 <= n < 64 (word indices mod 4), which
    // is a right rotation by 64 - n:
    //   word i = ( a[i+1] << n ) | ( a[i] >> ( 64 - n ) )
    const __m256i in_place = _mm256_srli_epi64 ( a, 64 - n );
    const __m256i from_up  = _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, n ), _MM_SHUFFLE ( 0, 3, 2, 1 ) ); // word i <- word i + 1

    return _mm256_or_si256 ( in_place, from_up );
}

__m256i _mm256_rli_si256 ( __m256i a, int n ) {

    // Rotate the whole 256-bit vector left by n bits.
    // Rotation is periodic in 256, so reduce n first.  In two's complement
    // this also maps a negative n to the equivalent left rotation
    // (e.g. -1 -> 255, i.e. rotate right by one).  Without the reduction,
    // n >= 256 would be handled as a rotation by 192 + n % 64.
    n &= 255;

    if ( n < 128 ) return n <  64 ? left_rotate_000_063 ( a, n ) : left_rotate_064_127 ( a, n % 64 );
    else           return n < 192 ? left_rotate_128_191 ( a, n % 64 ) : left_rotate_192_255 ( a, n % 64 );
}


__m256i right_rotate_000_063 ( __m256i a, int n ) { // 5

    // Rotate the 256-bit value right by n bits, 0 <= n < 64: every word is
    // shifted right and refilled with the low n bits of the word above it;
    // word 3 takes them from word 0.
    const __m256i body = _mm256_srli_epi64 ( a, n );
    const __m256i wrap = _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 0, 3, 2, 1 ) ); // word i <- word i + 1 (mod 4)

    return _mm256_or_si256 ( body, wrap );
}

__m256i right_rotate_064_127 ( __m256i a, int n ) { // 6

    // Rotate right by 64 + n bits, 0 <= n < 64 (word indices mod 4):
    //   word i = ( a[i+1] >> n ) | ( a[i+2] << ( 64 - n ) )
    const __m256i lo_bits = _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, n ),      _MM_SHUFFLE ( 0, 3, 2, 1 ) ); // word i <- word i + 1
    const __m256i hi_bits = _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 1, 0, 3, 2 ) ); // word i <- word i + 2

    return _mm256_or_si256 ( lo_bits, hi_bits );
}

__m256i right_rotate_128_191 ( __m256i a, int n ) { // 6

    // Rotate right by 128 + n bits, 0 <= n < 64 (word indices mod 4):
    //   word i = ( a[i+2] >> n ) | ( a[i+3] << ( 64 - n ) )
    const __m256i lo_bits = _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, n ),      _MM_SHUFFLE ( 1, 0, 3, 2 ) ); // word i <- word i + 2
    const __m256i hi_bits = _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 2, 1, 0, 3 ) ); // word i <- word i + 3

    return _mm256_or_si256 ( lo_bits, hi_bits );
}
__m256i right_rotate_192_255 ( __m256i a, int n ) { // 5

    // Rotate right by 192 + n bits, 0 <= n < 64 (word indices mod 4), which
    // is a left rotation by 64 - n:
    //   word i = ( a[i-1] >> n ) | ( a[i] << ( 64 - n ) )
    const __m256i in_place  = _mm256_slli_epi64 ( a, 64 - n );
    const __m256i from_down = _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, n ), _MM_SHUFFLE ( 2, 1, 0, 3 ) ); // word i <- word i - 1

    return _mm256_or_si256 ( in_place, from_down );
}

__m256i _mm256_rri_si256 ( __m256i a, int n ) {

    // Rotate the whole 256-bit vector right by n bits.
    // Rotation is periodic in 256, so reduce n first.  In two's complement
    // this also maps a negative n to the equivalent right rotation
    // (e.g. -1 -> 255, i.e. rotate left by one).  Without the reduction,
    // n >= 256 would be handled as a rotation by 192 + n % 64.
    n &= 255;

    if ( n < 128 ) return n <  64 ? right_rotate_000_063 ( a, n % 64 ) : right_rotate_064_127 ( a, n % 64 );
    else           return n < 192 ? right_rotate_128_191 ( a, n % 64 ) : right_rotate_192_255 ( a, n % 64 );
}

I have tried to arrange the _mm256_permute4x64_epi64 operations (in the cases where two of them are needed anyway) so that they can partially overlap, which should keep the overall latency to a minimum.

Most of the suggestions and clues given by commenters helped in putting the code together; thanks to them. Improvements and any other comments are, of course, welcome.

I think that Mysticial's answer is interesting, but too complicated to be used effectively for generalized shifting/rotating, e.g. in a library.

degski
  • 642
  • 5
  • 13