
I wrote the function int compare_16bytes(__m128i lhs, __m128i rhs) in order to compare two 16 byte numbers using SSE instructions: this function returns how many bytes are equal after performing the comparison.

Now I would like to use the above function to compare two byte arrays of arbitrary length. The length may not be a multiple of 16 bytes, so I need to deal with this problem. How could I complete the implementation of the function below? How could I improve it?

int fast_compare(const char* s, const char* t, int length)
{
    int result = 0;

    const char* sPtr = s;
    const char* tPtr = t;

    while(...)
    {
        const __m128i* lhs = (const __m128i*)sPtr;
        const __m128i* rhs = (const __m128i*)tPtr;

        // compare the next 16 bytes of s and t
        result += compare_16bytes(*lhs,*rhs);

        sPtr += 16;
        tPtr += 16;
    }

    return result;
}
enzom83
  • Use a for loop (length / 16 times), and pad zeros to lhs and ones to rhs if the remaining bytes are less than 16. The padding should be different so that it doesn't falsely count the padding as equal. – Oguz Meteer Mar 09 '13 at 17:42
  • `while (length >= 16) { /* use your function */ length -= 16; } if (length) /* use a version that compares length (up to 15) bytes */;` – pmg Mar 09 '13 at 17:42
  • FYI this is often called the [*Hamming distance*](http://en.wikipedia.org/wiki/Hamming_distance) – this may be useful as a search term. – Konrad Rudolph Mar 09 '13 at 18:02
  • The C library includes functions like `memset()` that work on any number of bytes, but must be fast. For speed these may be implemented as inline functions, so you might be able to find source for them in an include file. Studying how they are implemented may help you solve this problem. Also check Agner Fog's asm library: http://www.agner.org/optimize/#asmlib – steveha Mar 09 '13 at 18:12
  • @steveha: SSE instructions can simultaneously compare 16 bytes. – enzom83 Mar 09 '13 at 18:20
  • A better approach is to not use your `compare_16bytes` function at all and do a compare/accumulate vertically. Then at the end do a reduction. (You will also need to do a reduction every 255 iterations to keep the sum vector from overflowing.) – Mysticial Mar 09 '13 at 18:41
  • Related: [How to count character occurrences using SIMD](//stackoverflow.com/q/54541129) for counting matches on one byte. The counting is the same if the compare vector is `_mm_set1_epi8(c)` or if it's loaded from another array. – Peter Cordes Feb 11 '19 at 06:28
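A minimal sketch of the structure pmg's comment describes: whole 16-byte chunks go through a per-chunk counter, and the last 0..15 bytes are compared one at a time instead of padded. The body of `compare_16bytes` below is a hypothetical stand-in (the question does not show the asker's version); it assumes SSE2 and uses a compare, `_mm_movemask_epi8`, and a portable bit count. Loads use `_mm_loadu_si128`, so `s` and `t` need no particular alignment.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

/* Hypothetical stand-in for the asker's compare_16bytes(): returns how many
   of the 16 byte lanes are equal, via compare + movemask + bit count. */
static int compare_16bytes(__m128i lhs, __m128i rhs)
{
    unsigned mask = (unsigned)_mm_movemask_epi8(_mm_cmpeq_epi8(lhs, rhs));
    int count = 0;
    while (mask) {          /* portable popcount of the 16-bit mask */
        mask &= mask - 1;   /* clear the lowest set bit */
        ++count;
    }
    return count;
}

int fast_compare(const char *s, const char *t, int length)
{
    int result = 0;

    /* whole 16-byte chunks; unaligned loads, so no alignment is required */
    while (length >= 16) {
        __m128i lhs = _mm_loadu_si128((const __m128i *)s);
        __m128i rhs = _mm_loadu_si128((const __m128i *)t);
        result += compare_16bytes(lhs, rhs);
        s += 16;
        t += 16;
        length -= 16;
    }

    /* remaining 0..15 bytes, compared one at a time */
    while (length-- > 0) {
        if (*s++ == *t++)
            ++result;
    }
    return result;
}
```

Calling `compare_16bytes` once per chunk does a horizontal reduction on every iteration, which is exactly what Mysticial's comment suggests avoiding; the answers below accumulate vertically instead.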

4 Answers


As @Mysticial says in the comments above, do the compare and sum vertically and then just sum horizontally at the end of the main loop:

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <emmintrin.h>

// reference implementation
int fast_compare_ref(const char *s, const char *t, int length)
{
    int result = 0;
    int i;

    for (i = 0; i < length; ++i)
    {
        if (s[i] == t[i])
            result++;
    }
    return result;
}

// optimised implementation
int fast_compare(const char *s, const char *t, int length)
{
    int result = 0;
    int i;

    __m128i vsum = _mm_set1_epi32(0);
    for (i = 0; i < length - 15; i += 16)
    {
        __m128i vs, vt, v, vh, vl, vtemp;

        vs = _mm_loadu_si128((const __m128i *)&s[i]); // load 16 chars from input
        vt = _mm_loadu_si128((const __m128i *)&t[i]);
        v = _mm_cmpeq_epi8(vs, vt);             // compare
        vh = _mm_unpackhi_epi8(v, v);           // unpack compare result into 2 x 8 x 16 bit vectors
        vl = _mm_unpacklo_epi8(v, v);
        vtemp = _mm_madd_epi16(vh, vh);         // accumulate 16 bit vectors into 4 x 32 bit partial sums
        vsum = _mm_add_epi32(vsum, vtemp);
        vtemp = _mm_madd_epi16(vl, vl);
        vsum = _mm_add_epi32(vsum, vtemp);
    }

    // get sum of 4 x 32 bit partial sums
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8));
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4));
    result = _mm_cvtsi128_si32(vsum);

    // handle any residual bytes ( < 16)
    if (i < length)
    {
        result += fast_compare_ref(&s[i], &t[i], length - i);
    }

    return result;
}

// test harness
int main(void)
{
    const int n = 1000000;
    char *s = malloc(n);
    char *t = malloc(n);
    int i, result_ref, result;

    srand(time(NULL));

    for (i = 0; i < n; ++i)
    {
        s[i] = rand();
        t[i] = rand();
    }

    result_ref = fast_compare_ref(s, t, n);
    result = fast_compare(s, t, n);

    printf("result_ref = %d, result = %d\n", result_ref, result);

    free(s);
    free(t);
    return 0;
}

Compile and run the above test harness:

$ gcc -Wall -O3 -msse3 fast_compare.c -o fast_compare
$ ./fast_compare
result_ref = 3955, result = 3955
$ ./fast_compare
result_ref = 3947, result = 3947
$ ./fast_compare
result_ref = 3945, result = 3945

Note that there is one possibly non-obvious trick in the above SSE code where we use _mm_madd_epi16 to unpack and accumulate 16 bit 0/-1 values to 32 bit partial sums. We take advantage of the fact that -1*-1 = 1 (and 0*0 = 0 of course) - we're not really doing a multiply here, just unpacking and summing in one instruction.
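To see the `_mm_madd_epi16` trick in isolation, here is a small sketch that counts equal bytes in a single 16-byte block with the same unpack/madd sequence and the same final horizontal sum as the code above. The function name `count_equal_madd` is hypothetical, and SSE2 is assumed.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

/* Counts equal bytes in one 16-byte block. Unpacking the 0x00/0xFF compare
   result with itself gives 16-bit lanes of 0 or -1; madd(x, x) then sums
   adjacent squares, so each 32-bit lane holds 0, 1 or 2. */
int count_equal_madd(const char *s, const char *t)
{
    __m128i v  = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)s),
                                _mm_loadu_si128((const __m128i *)t));
    __m128i vh = _mm_unpackhi_epi8(v, v);   /* 8 x 16-bit lanes, each 0 or -1 */
    __m128i vl = _mm_unpacklo_epi8(v, v);
    __m128i sum = _mm_add_epi32(_mm_madd_epi16(vh, vh),  /* (-1)*(-1) = 1 */
                                _mm_madd_epi16(vl, vl));
    /* horizontal sum of the four 32-bit partial sums */
    sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
    sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4));
    return _mm_cvtsi128_si32(sum);
}
```

For example, comparing `"abcdefghijklmnop"` with `"abcXefgXijkXmnoX"` differs in 4 of the 16 positions, so the function returns 12.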


UPDATE: as noted in the comments below, this solution is not optimal - I just took a fairly optimal 16 bit solution and added 8 bit to 16 bit unpacking to make it work for 8 bit data. However for 8 bit data there are more efficient methods, e.g. using psadbw/_mm_sad_epu8. I'll leave this answer here for posterity, and for anyone who might want to do this kind of thing with 16 bit data, but really one of the other answers which doesn't require unpacking the input data should be the accepted answer.

Paul R
  • Great! It works properly! Moreover, is it important that the two vectors `s` and `t` are _aligned_? What is the alignment? – enzom83 Mar 11 '13 at 14:08
  • I've used `_mm_loadu_si128` in the above example so that it doesn't matter about alignment. If you can guarantee that `s` and `t` are 16 byte aligned though then use `_mm_load_si128` instead of `_mm_loadu_si128` for better performance, particularly on older CPUs. – Paul R Mar 11 '13 at 21:38
  • _mm_setzero_si128 () may be faster than _mm_set1_epi32(0) for zeroing vsum. – leecbaker Feb 03 '15 at 18:26
  • There shouldn't be any difference with a decent compiler, but yes, it might not be a bad idea all the same. – Paul R Feb 03 '15 at 18:36
  • there's a faster way to accumulate even without unrolling for `psubb`, using just `psadbw` / `paddq`. I turned my comments into an answer. – Peter Cordes Jun 21 '16 at 05:45
  • @PeterCordes: thanks, yes, I think this came up in a similar question recently - I've made a mental note to look out for opportunities to exploit `psadbw` in future. – Paul R Jun 21 '16 at 06:08

Using partial sums in 16 x uint8 elements may give even better performance.
I have divided the loop into an inner loop and an outer loop.
The inner loop sums uint8 elements (each uint8 element can sum up to 255 "1"s).
Small trick: _mm_cmpeq_epi8 sets equal elements to 0xFF, and (char)0xFF = -1, so you can subtract the result from the sum (subtracting -1 adds 1).
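The subtraction trick can be checked on its own with a short sketch that accumulates up to 255 blocks into the 16 byte counters with `_mm_sub_epi8` and then reduces them in scalar code. The function name `count_equal_psubb` and its `nblocks <= 255` precondition are assumptions of this sketch, not part of the answer.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stdint.h>

/* Counts equal bytes in nblocks consecutive 16-byte blocks.
   Precondition (assumed): nblocks <= 255, so no uint8 lane can overflow. */
int count_equal_psubb(const char *s, const char *t, int nblocks)
{
    __m128i acc = _mm_setzero_si128();
    for (int k = 0; k < nblocks; ++k) {
        __m128i v = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)(s + 16 * k)),
                                   _mm_loadu_si128((const __m128i *)(t + 16 * k)));
        acc = _mm_sub_epi8(acc, v);   /* acc - (-1) == acc + 1 where equal */
    }
    uint8_t lanes[16];                /* simple scalar reduction of the counters */
    _mm_storeu_si128((__m128i *)lanes, acc);
    int total = 0;
    for (int k = 0; k < 16; ++k)
        total += lanes[k];
    return total;
}
```

With 255 blocks of identical data every lane reaches exactly 255, which is why 255 * 16 = 4080 bytes is the block size used below.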

Here is my optimized version of fast_compare:

// Requires <emmintrin.h> (SSE2) and fast_compare_ref() from the answer above.
int fast_compare2(const char *s, const char *t, int length)
{
    int result = 0;
    int inner_length = length;
    int i;
    int j = 0;

    // Points to the beginning of the current 4080-element block.
    const char *s0 = s;
    const char *t0 = t;


    __m128i vsum = _mm_setzero_si128();

    // Outer loop: add up the partial sums of each 4080-element block.
    for (i = 0; i < length; i += 4080)
    {
        __m128i vsum_uint8 = _mm_setzero_si128(); //16 uint8 sum elements (each uint8 element can sum up to 255).
        __m128i vh, vl, vhl, vhl_lo, vhl_hi;

        // Points to the beginning of the current 4080-element block.
        s0 = s + i;
        t0 = t + i;

        if (i + 4080 <= length)
        {
            inner_length = 4080;
        }
        else
        {
            inner_length = length - i;
        }

        //Inner loop - sum up to 4080 (compared) results.
        //Each uint8 element can sum up to 255. 16 uint8 elements can sum up to 255*16 = 4080 (compared) results.
        //////////////////////////////////////////////////////////////////////////
        for (j = 0; j < inner_length-15; j += 16)
        {
              __m128i vs, vt, v;

              vs = _mm_loadu_si128((const __m128i *)&s0[j]); // load 16 chars from input
              vt = _mm_loadu_si128((const __m128i *)&t0[j]);
              v = _mm_cmpeq_epi8(vs, vt);             // compare - set to 0xFF where equal, and 0 otherwise.

              //Consider this: (char)0xFF = (-1)
              vsum_uint8 = _mm_sub_epi8(vsum_uint8, v); //Subtract the comparison result - subtract (-1) where equal.
        }
        //////////////////////////////////////////////////////////////////////////

        vh = _mm_unpackhi_epi8(vsum_uint8, _mm_setzero_si128());        // unpack result into 2 x 8 x 16 bit vectors
        vl = _mm_unpacklo_epi8(vsum_uint8, _mm_setzero_si128());
        vhl = _mm_add_epi16(vh, vl);    //Sum high and low as uint16 elements.

        vhl_hi = _mm_unpackhi_epi16(vhl, _mm_setzero_si128());   // unpack sum of vh and vl into 2 x 4 x 32 bit vectors
        vhl_lo = _mm_unpacklo_epi16(vhl, _mm_setzero_si128());   // unpack sum of vh and vl into 2 x 4 x 32 bit vectors

        vsum = _mm_add_epi32(vsum, vhl_hi);
        vsum = _mm_add_epi32(vsum, vhl_lo);
    }

    // get sum of 4 x 32 bit partial sums
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8));
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4));
    result = _mm_cvtsi128_si32(vsum);

    // handle any residual bytes ( < 16)
    if (j < inner_length)
    {
        result += fast_compare_ref(&s0[j], &t0[j], inner_length - j);
    }

    return result;
}
Rotem
  • Heh, I should have looked at the new answer before commenting on Paul's; I suggested the same thing (`psubb` inside an inner loop). This is what I meant, except you should use `psadbw` to do the horizontal sum of `vsum_uint8` (see my comments on Paul's answer). – Peter Cordes Jun 21 '16 at 05:32
  • I thought of using horizontal sum, but decided to keep SSE2 compatibility. – Rotem Jun 21 '16 at 17:38
  • Are you talking about `phaddd`? That's not what I said. `phaddd`'s [only advantage is code-size](http://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-float-vector-sum-on-x86/35270026#35270026) on current CPUs. See also my answer on this question, which uses only SSE2 instructions. – Peter Cordes Jun 21 '16 at 17:40

The fastest way for large inputs is Rotem's answer, where the inner loop is pcmpeqb / psubb, breaking out to horizontally sum before any byte element of the vector accumulator overflows. Do the hsum of unsigned bytes with psadbw against an all-zero vector.
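A minimal sketch of that combination, assuming SSE2: the inner loop is Rotem's `pcmpeqb` / `psubb`, limited to 255 iterations so no byte lane overflows, and each byte accumulator is reduced with `_mm_sad_epu8` (`psadbw`) against zero and added into two 64-bit sums. The function name `count_equal_nested` is hypothetical; the last 0..15 bytes are handled in scalar code.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

int count_equal_nested(const char *s, const char *t, int length)
{
    __m128i total = _mm_setzero_si128();   /* two 64-bit partial sums */
    int i = 0;

    while (i + 16 <= length) {
        /* at most 255 inner iterations, so no uint8 lane can overflow */
        int blocks = (length - i) / 16;
        if (blocks > 255)
            blocks = 255;

        __m128i acc = _mm_setzero_si128();
        for (int k = 0; k < blocks; ++k, i += 16) {
            __m128i v = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)(s + i)),
                                       _mm_loadu_si128((const __m128i *)(t + i)));
            acc = _mm_sub_epi8(acc, v);    /* +1 per matching byte */
        }
        /* psadbw against zero sums the 8 unsigned bytes in each 64-bit half */
        total = _mm_add_epi64(total, _mm_sad_epu8(acc, _mm_setzero_si128()));
    }

    /* add the two 64-bit halves; the count fits in 32 bits for int lengths */
    total = _mm_add_epi64(total, _mm_unpackhi_epi64(total, total));
    int result = _mm_cvtsi128_si32(total);

    for (; i < length; ++i)                /* scalar tail, < 16 bytes */
        result += (s[i] == t[i]);
    return result;
}
```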

See also How to count character occurrences using SIMD: its AVX2 C++ intrinsics code can count matches against a vector loaded from another array instead of that question's _mm_set1_epi8(char_to_count). Adding up the compare results efficiently works the same way, using psadbw for a horizontal sum.


Without unrolling / nested loops, the best option is probably

pcmpeqb   -> vector of  0  or  0xFF  elements
psadbw    -> two 64bit sums of  (0*no_matches + 0xFF*matches)
paddq     -> accumulate the psadbw result in a vector accumulator

#outside the loop:
horizontal sum
divide the result by 255
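The three-instruction loop above translates to SSE2 intrinsics roughly as follows (`_mm_cmpeq_epi8` for pcmpeqb, `_mm_sad_epu8` for psadbw, `_mm_add_epi64` for paddq). This is a sketch under that assumption; the function name `count_equal_sad` is hypothetical, and bytes past the last full 16-byte block are counted in scalar code.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stdint.h>

int count_equal_sad(const char *s, const char *t, int length)
{
    __m128i acc = _mm_setzero_si128();
    int i = 0;
    for (; i + 16 <= length; i += 16) {
        __m128i v = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)(s + i)),
                                   _mm_loadu_si128((const __m128i *)(t + i)));
        /* each matching byte contributes 0xFF (255) to its 64-bit half */
        acc = _mm_add_epi64(acc, _mm_sad_epu8(v, _mm_setzero_si128()));
    }
    /* outside the loop: horizontal sum, then divide by 255 */
    acc = _mm_add_epi64(acc, _mm_unpackhi_epi64(acc, acc));
    uint64_t sum;
    _mm_storel_epi64((__m128i *)&sum, acc);    /* 255 * matches may exceed 32 bits */
    int result = (int)(sum / 255);
    for (; i < length; ++i)                    /* scalar tail, < 16 bytes */
        result += (s[i] == t[i]);
    return result;
}
```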

If you don't have a lot of register pressure in your loop, use psadbw against a vector of 0x7f instead of an all-zero vector:

  • psadbw(0x00, set1(0x7f)) => sum += 0x7f
  • psadbw(0xff, set1(0x7f)) => sum += 0x80

So instead of dividing by 255 (which the compiler should do efficiently without an actual div), you just have to subtract n * 0x7f, where n is the number of elements.
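Here is a sketch of the 0x7f variant, assuming SSE2. Each byte contributes 0x7f (no match) or 0x80 (match), so after the loop the count is the sum minus 0x7f times the number of bytes the vector loop processed. The function name `count_equal_sad7f` is hypothetical.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stdint.h>

int count_equal_sad7f(const char *s, const char *t, int length)
{
    const __m128i bias = _mm_set1_epi8(0x7f);  /* the extra register this variant uses */
    __m128i acc = _mm_setzero_si128();
    int i = 0;
    for (; i + 16 <= length; i += 16) {
        __m128i v = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)(s + i)),
                                   _mm_loadu_si128((const __m128i *)(t + i)));
        /* |0x00 - 0x7f| = 0x7f, |0xff - 0x7f| = 0x80 */
        acc = _mm_add_epi64(acc, _mm_sad_epu8(v, bias));
    }
    acc = _mm_add_epi64(acc, _mm_unpackhi_epi64(acc, acc));
    uint64_t sum;
    _mm_storel_epi64((__m128i *)&sum, acc);
    /* i is the number of bytes handled by the vector loop */
    int result = (int)(sum - 0x7fULL * (uint64_t)i);
    for (; i < length; ++i)                    /* scalar tail, < 16 bytes */
        result += (s[i] == t[i]);
    return result;
}
```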

Also note that paddq is slow on pre-Nehalem CPUs and on Atom, so you could use paddd (_mm_add_epi32) instead if you don't expect 128 * the count to ever overflow a 32-bit integer.

This compares very well with Paul R's pcmpeqb / 2x punpck / 2x pmaddwd / 2x paddd.


But with a small unroll, you could accumulate 4 or 8 compare results with psubb before psadbw / paddq.
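A sketch of the small unroll, assuming SSE2: four compare results are accumulated with `psubb` (each byte lane reaches at most 4, so it cannot overflow) before a single `psadbw` / `paddq`. The unroll factor of 4 and the name `count_equal_unroll4` are illustrative choices; bytes left over after the last 64-byte group go through scalar code.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

int count_equal_unroll4(const char *s, const char *t, int length)
{
    __m128i total = _mm_setzero_si128();
    int i = 0;
    for (; i + 64 <= length; i += 64) {
        __m128i acc = _mm_setzero_si128();
        for (int k = 0; k < 64; k += 16) {     /* four 16-byte compares per group */
            __m128i v = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)(s + i + k)),
                                       _mm_loadu_si128((const __m128i *)(t + i + k)));
            acc = _mm_sub_epi8(acc, v);        /* +1 per matching byte */
        }
        /* one psadbw + paddq per 64 bytes instead of per 16 bytes */
        total = _mm_add_epi64(total, _mm_sad_epu8(acc, _mm_setzero_si128()));
    }
    total = _mm_add_epi64(total, _mm_unpackhi_epi64(total, total));
    int result = _mm_cvtsi128_si32(total);
    for (; i < length; ++i)                    /* scalar tail, up to 63 bytes */
        result += (s[i] == t[i]);
    return result;
}
```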

Peter Cordes

The integer comparison in SSE produces bytes that are either all zeros or all ones. If you want to count, you first need to shift the comparison result right by 7 (a logical, not arithmetic, shift), then add it to the result vector. At the end, you still need to reduce the result vector by summing its elements. This reduction has to be done in scalar code, or with a sequence of adds and shifts. Usually this part is not worth troubling with.
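A sketch of the shift approach, assuming SSE2. SSE2 has no per-byte shift instruction, so this emulates a logical byte shift by 7 with a 16-bit shift (`_mm_srli_epi16`) followed by a mask of 1 in every byte. The byte counters are reduced in scalar code every 255 blocks so they cannot overflow. The function name `count_equal_shift` is hypothetical.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stdint.h>

int count_equal_shift(const char *s, const char *t, int length)
{
    const __m128i one = _mm_set1_epi8(1);
    int result = 0, i = 0;
    while (i + 16 <= length) {
        __m128i acc = _mm_setzero_si128();
        for (int k = 0; k < 255 && i + 16 <= length; ++k, i += 16) {
            __m128i v = _mm_cmpeq_epi8(_mm_loadu_si128((const __m128i *)(s + i)),
                                       _mm_loadu_si128((const __m128i *)(t + i)));
            /* 0xFF -> 1, 0x00 -> 0: 16-bit shift by 7, keep bit 0 of each byte */
            acc = _mm_add_epi8(acc, _mm_and_si128(_mm_srli_epi16(v, 7), one));
        }
        uint8_t lanes[16];                  /* scalar reduction of the 16 counters */
        _mm_storeu_si128((__m128i *)lanes, acc);
        for (int k = 0; k < 16; ++k)
            result += lanes[k];
    }
    for (; i < length; ++i)                 /* scalar tail, < 16 bytes */
        result += (s[i] == t[i]);
    return result;
}
```

Compared with the psubb trick in Rotem's answer, the shift and mask cost two extra instructions per block for the same result.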

Photon