
Given the following input bytes:

var vBytes = new Vector<byte>(new byte[] {72, 101, 55, 08, 108, 111, 55, 87, 111, 114, 108, 55, 100, 55, 55, 20});

And the given mask:

var mask = new Vector<byte>(55);

How can I find the count of byte 55 in the input array?

I have tried xoring the vBytes with the mask:

var xored = Vector.Xor(mask, vBytes);

which gives:

<127, 82, 0, 91, 91, 88, 0, 96, 88, 69, 91, 0, 83, 0, 0, 35>

But don't know how I can get the count from that.

For the sake of simplicity, let's assume that the input length is always equal to Vector<byte>.Count.

MaYaN
  • You mean without simple for loop? – Evk Mar 29 '18 at 10:03
  • FYI - `Vector.Equals(vBytes,mask)` is probably more intuitive than xor - it returns a vector of 255s/0s. How to count them, though... – Marc Gravell Mar 29 '18 at 10:04
  • @MarcGravell Awesome! I got it, will update with the answer. – MaYaN Mar 29 '18 at 10:07
  • `Vector.Dot(Vector.Negate(Vector.Equals(vBytes, new Vector<byte>(55))), new Vector<byte>(1))` would do it. However, I have no experience with SIMD and I don't know if this is a reasonable approach. – Jeroen Mostert Mar 29 '18 at 10:07
  • I guess you could do a "shuffle" to pull a portion of each into a single chunk? but AFAIK `Vector` doesn't expose "shuffle" - you need full avx/avx2 for that (or possibly sse3 for the byte-128 version) – Marc Gravell Mar 29 '18 at 10:09
  • @Evk it is certainly something that should be tested using something like BenchmarkDotNet; I would have expected `.Dot` to use `__m256 _mm256_dp_ps` or similar, no? – Marc Gravell Mar 29 '18 at 10:21
  • @MarcGravell yeah it should, my last comment was nonsense so I deleted it :) Though still not sure if it will be faster or not, but indeed should be measured. – Evk Mar 29 '18 at 10:22
  • @MarcGravell: yup, packed byte compare, then use `psadbw` to horizontal-sum those results into 64-bit elements. – Peter Cordes Mar 30 '18 at 04:59

4 Answers


(For an AVX2 C intrinsics implementation of the idea below, in case a concrete example helps, see How to count character occurrences using SIMD.)

In asm, you want pcmpeqb to produce a vector of 0 or 0xFF. Treated as signed integers, that's 0/-1.

Then use the compare result as integer values with psubb to add 0 or 1 to the counter for each element. (Subtracting -1 is the same as adding +1.)

That can overflow after 256 iterations, so sometime before that, use psadbw against _mm_setzero_si128() to horizontally sum those unsigned bytes (without overflow) into 64-bit integers (one 64-bit integer per group of 8 bytes). Then paddq to accumulate 64-bit totals.

Accumulating before you overflow can be done with a nested loop, or just at the end of a regular unrolled loop. psadbw is fast (it's a key building block for video-encoding motion search), so it's not bad to accumulate every 4 compares, or even after every compare and skip the psubb entirely.

See Agner Fog's optimization guides for more details on x86. According to his instruction tables, psadbw xmm / vpsadbw ymm run at 1 vector per clock cycle on Skylake, with 3-cycle latency (and only 1 uop of front-end bandwidth). All the instructions mentioned above are also single-uop, and run on more than one port (so they don't necessarily compete with each other for throughput). Their 128-bit versions only require SSE2.
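
For concreteness, here's roughly how that loop could look with C#'s System.Runtime.Intrinsics API. This is a minimal sketch, not a tuned implementation: it assumes SSE2, a length that's a multiple of 16, and at most 255 vectors processed before the horizontal sum so the byte accumulator can't wrap. The class and method names are illustrative, not from any library.

using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class ByteCounter
{
    // pcmpeqb / psubb inner loop, with one psadbw horizontal sum at the end.
    public static long CountEqualBytes(ReadOnlySpan<byte> data, byte value)
    {
        var needle = Vector128.Create(value);
        var accumulator = Vector128<byte>.Zero;

        for (int i = 0; i < data.Length; i += 16)
        {
            var chunk = MemoryMarshal.Read<Vector128<byte>>(data.Slice(i, 16)); // movdqu
            var eq = Sse2.CompareEqual(chunk, needle);    // pcmpeqb: 0x00 or 0xFF per byte
            accumulator = Sse2.Subtract(accumulator, eq); // psubb: subtract -1 == add +1
        }

        // psadbw against zero horizontally sums each 8-byte group into a 64-bit lane.
        var sums = Sse2.SumAbsoluteDifferences(accumulator, Vector128<byte>.Zero).AsInt64();
        return sums.GetElement(0) + sums.GetElement(1);
    }
}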


If you really only have one vector at a time to count, and aren't looping over memory, then pcmpeqb / psadbw / pshufd (copy high half to low) / paddd / movd eax, xmm0 probably gives you 255 * number-of-matches in an integer register. One extra vector instruction (like a subtract from zero, an AND with 1, or pabsb (absolute value)) would remove the x255 scale factor.
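
In C#, that single-vector case might look like the sketch below, using the negate-first variant so the result is the match count directly rather than 255x it. It assumes SSE2 via System.Runtime.Intrinsics; the names are illustrative.

using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class SingleVector
{
    // Count matching bytes within one 16-byte vector.
    public static int CountInVector(Vector128<byte> input, byte value)
    {
        var eq = Sse2.CompareEqual(input, Vector128.Create(value)); // 0x00 or 0xFF per byte
        var ones = Sse2.Subtract(Vector128<byte>.Zero, eq);         // 0x00 or 0x01 per byte
        var sums = Sse2.SumAbsoluteDifferences(ones, Vector128<byte>.Zero).AsInt64(); // psadbw
        return (int)(sums.GetElement(0) + sums.GetElement(1));      // add the two 64-bit halves
    }
}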


I don't know how to write that in C# SIMD, but you definitely do not want a dot product! Unpacking and converting to FP would be about 4x slower than the above, simply because a fixed-width vector holds 4x as many bytes as floats, and dpps (_mm_dp_ps) is not fast: 4 uops, and one per 1.5 cycles throughput on Skylake. If you do have to horizontal-sum something other than unsigned bytes, see Fastest way to do horizontal SSE vector sum (or other reduction) (my answer there also covers integer).

If Vector.Dot uses pmaddubsw / pmaddwd for integer vectors, that might not be as bad, but doing a multi-step horizontal sum for every vector of compare results is still bad compared to psadbw, and especially compared to byte accumulators that you only horizontal-sum occasionally.

Or C# might optimize away an actual multiply by a constant vector of ones. In any case, the first part of this answer describes the code you want the CPU to be running; make that happen however you like, using whatever source code gets you there.

Peter Cordes

I know that I'm super late to the party, but so far none of the answers here actually provide a full solution. Here's my best attempt at one, derived from this Gist and the DotNet source code. All credit goes to the DotNet team and community members here (especially @Peter Cordes).

Usage:

var bytes = Encoding.ASCII.GetBytes("The quick brown fox jumps over the lazy dog.");
var byteCount = bytes.AsSpan().OccurrencesOf((byte)32);

var chars = "The quick brown fox jumps over the lazy dog.";
var charCount = chars.AsSpan().OccurrencesOf(' ');

Code:

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

public static class VectorExtensions
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nuint GetByteVector128SpanLength(nuint offset, int length) =>
        ((nuint)(uint)((length - (int)offset) & ~(Vector128<byte>.Count - 1)));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nuint GetByteVector256SpanLength(nuint offset, int length) =>
        ((nuint)(uint)((length - (int)offset) & ~(Vector256<byte>.Count - 1)));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nint GetCharVector128SpanLength(nint offset, nint length) =>
        ((length - offset) & ~(Vector128<ushort>.Count - 1));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nint GetCharVector256SpanLength(nint offset, nint length) =>
        ((length - offset) & ~(Vector256<ushort>.Count - 1));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector128<byte> LoadVector128(ref byte start, nuint offset) =>
        Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<byte> LoadVector256(ref byte start, nuint offset) =>
        Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector128<ushort> LoadVector128(ref char start, nint offset) =>
        Unsafe.ReadUnaligned<Vector128<ushort>>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref start, offset)));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<ushort> LoadVector256(ref char start, nint offset) =>
        Unsafe.ReadUnaligned<Vector256<ushort>>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref start, offset)));
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    private static unsafe int OccurrencesOf(ref byte searchSpace, byte value, int length) {
        var lengthToExamine = ((nuint)length);
        var offset = ((nuint)0);
        var result = 0L;

        if (Sse2.IsSupported || Avx2.IsSupported) {
            if (31 < length) {
                lengthToExamine = UnalignedCountVector128(ref searchSpace);
            }
        }

    SequentialScan:
        while (7 < lengthToExamine) {
            ref byte current = ref Unsafe.AddByteOffset(ref searchSpace, offset);

            if (value == current) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 1)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 2)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 3)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 4)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 5)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 6)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 7)) {
                ++result;
            }

            lengthToExamine -= 8;
            offset += 8;
        }

        while (3 < lengthToExamine) {
            ref byte current = ref Unsafe.AddByteOffset(ref searchSpace, offset);

            if (value == current) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 1)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 2)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 3)) {
                ++result;
            }

            lengthToExamine -= 4;
            offset += 4;
        }

        while (0 < lengthToExamine) {
            if (value == Unsafe.AddByteOffset(ref searchSpace, offset)) {
                ++result;
            }

            --lengthToExamine;
            ++offset;
        }

        if (offset < ((nuint)(uint)length)) {
            if (Avx2.IsSupported) {
                if (0 != (((nuint)(uint)Unsafe.AsPointer(ref searchSpace) + offset) & (nuint)(Vector256<byte>.Count - 1))) {
                    var sum = Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<byte>.Zero, Sse2.CompareEqual(Vector128.Create(value), LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64();

                    offset += 16;
                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                lengthToExamine = GetByteVector256SpanLength(offset, length);

                var searchMask = Vector256.Create(value);

                if (127 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        var accumulator0 = Vector256<byte>.Zero;
                        var accumulator1 = Vector256<byte>.Zero;
                        var accumulator2 = Vector256<byte>.Zero;
                        var accumulator3 = Vector256<byte>.Zero;
                        var loopIndex = ((nuint)0);
                        var loopLimit = Math.Min(255, (lengthToExamine / 128));

                        do {
                            accumulator0 = Avx2.Subtract(accumulator0, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset)));
                            accumulator1 = Avx2.Subtract(accumulator1, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 32))));
                            accumulator2 = Avx2.Subtract(accumulator2, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 64))));
                            accumulator3 = Avx2.Subtract(accumulator3, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 96))));
                            loopIndex++;
                            offset += 128;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (128 * loopLimit);
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector256<byte>.Zero).AsInt64());
                    } while (127 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (31 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(Avx2.Subtract(Vector256<byte>.Zero, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset))).AsByte(), Vector256<byte>.Zero).AsInt64());
                        lengthToExamine -= 32;
                        offset += 32;
                    } while (31 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (offset < ((nuint)(uint)length)) {
                    lengthToExamine = (((nuint)(uint)length) - offset);

                    goto SequentialScan;
                }
            }
            else if (Sse2.IsSupported) {
                lengthToExamine = GetByteVector128SpanLength(offset, length);

                var searchMask = Vector128.Create(value);

                if (63 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        var accumulator0 = Vector128<byte>.Zero;
                        var accumulator1 = Vector128<byte>.Zero;
                        var accumulator2 = Vector128<byte>.Zero;
                        var accumulator3 = Vector128<byte>.Zero;
                        var loopIndex = ((nuint)0);
                        var loopLimit = Math.Min(255, (lengthToExamine / 64));

                        do {
                            accumulator0 = Sse2.Subtract(accumulator0, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset)));
                            accumulator1 = Sse2.Subtract(accumulator1, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 16))));
                            accumulator2 = Sse2.Subtract(accumulator2, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 32))));
                            accumulator3 = Sse2.Subtract(accumulator3, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 48))));
                            loopIndex++;
                            offset += 64;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (64 * loopLimit);
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector128<byte>.Zero).AsInt64());
                    } while (63 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (15 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<byte>.Zero, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64());
                        lengthToExamine -= 16;
                        offset += 16;
                    } while (15 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (offset < ((nuint)(uint)length)) {
                    lengthToExamine = (((nuint)(uint)length) - offset);

                    goto SequentialScan;
                }
            }
        }

        return ((int)result);
    }
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    private static unsafe int OccurrencesOf(ref char searchSpace, char value, int length) {
        var lengthToExamine = ((nint)length);
        var offset = ((nint)0);
        var result = 0L;

        if (0 != ((int)Unsafe.AsPointer(ref searchSpace) & 1)) { }
        else if (Sse2.IsSupported || Avx2.IsSupported) {
            if (15 < length) {
                lengthToExamine = UnalignedCountVector128(ref searchSpace);
            }
        }

    SequentialScan:
        while (3 < lengthToExamine) {
            ref char current = ref Unsafe.Add(ref searchSpace, offset);

            if (value == current) {
                ++result;
            }
            if (value == Unsafe.Add(ref current, 1)) {
                ++result;
            }
            if (value == Unsafe.Add(ref current, 2)) {
                ++result;
            }
            if (value == Unsafe.Add(ref current, 3)) {
                ++result;
            }

            lengthToExamine -= 4;
            offset += 4;
        }

        while (0 < lengthToExamine) {
            if (value == Unsafe.Add(ref searchSpace, offset)) {
                ++result;
            }

            --lengthToExamine;
            ++offset;
        }

        if (offset < length) {
            if (Avx2.IsSupported) {
                if (0 != (((nint)Unsafe.AsPointer(ref Unsafe.Add(ref searchSpace, offset))) & (Vector256<byte>.Count - 1))) {
                    var sum = Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<ushort>.Zero, Sse2.CompareEqual(Vector128.Create(value), LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64();

                    offset += 8;
                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                lengthToExamine = GetCharVector256SpanLength(offset, length);

                var searchMask = Vector256.Create(value);

                if (63 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        var accumulator0 = Vector256<ushort>.Zero;
                        var accumulator1 = Vector256<ushort>.Zero;
                        var accumulator2 = Vector256<ushort>.Zero;
                        var accumulator3 = Vector256<ushort>.Zero;
                        var loopIndex = 0;
                        var loopLimit = Math.Min(255, (lengthToExamine / 64));

                        do {
                            accumulator0 = Avx2.Subtract(accumulator0, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset)));
                            accumulator1 = Avx2.Subtract(accumulator1, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 16))));
                            accumulator2 = Avx2.Subtract(accumulator2, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 32))));
                            accumulator3 = Avx2.Subtract(accumulator3, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 48))));
                            loopIndex++;
                            offset += 64;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (64 * loopLimit);
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector256<byte>.Zero).AsInt64());
                    } while (63 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (15 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(Avx2.Subtract(Vector256<ushort>.Zero, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset))).AsByte(), Vector256<byte>.Zero).AsInt64());
                        lengthToExamine -= 16;
                        offset += 16;
                    } while (15 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (offset < length) {
                    lengthToExamine = (length - offset);

                    goto SequentialScan;
                }
            }
            else if (Sse2.IsSupported) {
                lengthToExamine = GetCharVector128SpanLength(offset, length);

                var searchMask = Vector128.Create(value);

                if (31 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        var accumulator0 = Vector128<ushort>.Zero;
                        var accumulator1 = Vector128<ushort>.Zero;
                        var accumulator2 = Vector128<ushort>.Zero;
                        var accumulator3 = Vector128<ushort>.Zero;
                        var loopIndex = 0;
                        var loopLimit = Math.Min(255, (lengthToExamine / 32));

                        do {
                            accumulator0 = Sse2.Subtract(accumulator0, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset)));
                            accumulator1 = Sse2.Subtract(accumulator1, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 8))));
                            accumulator2 = Sse2.Subtract(accumulator2, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 16))));
                            accumulator3 = Sse2.Subtract(accumulator3, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 24))));
                            loopIndex++;
                            offset += 32;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (32 * loopLimit);
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector128<byte>.Zero).AsInt64());
                    } while (31 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (7 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<ushort>.Zero, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64());
                        lengthToExamine -= 8;
                        offset += 8;
                    } while (7 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (offset < length) {
                    lengthToExamine = (length - offset);

                    goto SequentialScan;
                }
            }
        }

        return ((int)result);
    }
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe nuint UnalignedCountVector128(ref byte searchSpace) {
        nint unaligned = ((nint)Unsafe.AsPointer(ref searchSpace) & (Vector128<byte>.Count - 1));

        return ((nuint)(uint)((Vector128<byte>.Count - unaligned) & (Vector128<byte>.Count - 1)));
    }
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe nint UnalignedCountVector128(ref char searchSpace) {
        const int ElementsPerByte = (sizeof(ushort) / sizeof(byte));

        return ((nint)(uint)(-(int)Unsafe.AsPointer(ref searchSpace) / ElementsPerByte) & (Vector128<ushort>.Count - 1));
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this ReadOnlySpan<byte> span, byte value) =>
        OccurrencesOf(
            length: span.Length,
            searchSpace: ref MemoryMarshal.GetReference(span),
            value: value
        );
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this Span<byte> span, byte value) =>
        ((ReadOnlySpan<byte>)span).OccurrencesOf(value);
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this ReadOnlySpan<char> span, char value) =>
        OccurrencesOf(
            length: span.Length,
            searchSpace: ref MemoryMarshal.GetReference(span),
            value: value
        );
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this Span<char> span, char value) =>
        ((ReadOnlySpan<char>)span).OccurrencesOf(value);
}
Kittoes0124
  • You can avoid the `Avx2.Subtract(Vector256.Zero, accumulator0)` cleanup before PSADBW by using sub instead of add in the first place. i.e. `accumulator -= cmp()` because cmp results are -1 or 0. Also, do `left + right` and then reduce that, instead of separately extracting all 4 elements. – Peter Cordes Sep 27 '21 at 18:36
  • Also, does this work for long buffers with many matches? Maybe I missed it, but I don't see an upper limit on `lengthToExamine` that would prevent a count wrapping from 255 back to 0. You're using 16-bit wide characters, but `psadbw` only sums *bytes*, so unless you split high and low halves of each u16 to count separately, it breaks if any have non-zero high halves. (Given worst-case input, the max length that's guaranteed safe with your AVX2 version is 32 B / vec * 4 vecs * 255 per element = only 32640 bytes before potentially losing counts to overflow, with this hsum method) – Peter Cordes Sep 27 '21 at 18:43
  • @PeterCordes As usual Peter, you're a hero. Lemme review and get back to ya. – Kittoes0124 Sep 27 '21 at 18:44
  • Cheers. Thanks for writing up an actual C# implementation; I don't really know C# beyond seeing it on SO, so I wasn't going to attempt this. [How to count character occurrences using SIMD](https://stackoverflow.com/q/54541129) uses nested loops to deal with overflow, might want to take a look at exactly how it's implemented. – Peter Cordes Sep 27 '21 at 18:49
  • If you want to let your 16-bit elements get up to counts of 65535, you can hsum them by starting with `pmaddwd` with `set1_epi16(1)`, so we just get an efficient widening horizontal add of pairs, reducing the problem to an AVX2 hsum of 32-bit integer elements, as mentioned in [Fastest way to do horizontal SSE vector sum (or other reduction)](https://stackoverflow.com/a/35270026). (I had a look at that idea of splitting up MSB / LSB halves of 16-bit elements for separate vpsadbw, but it comes out to more instructions than vpmaddwd -> 32-bit hsum; a sketch of the vpmaddwd reduction appears after this comment thread.) – Peter Cordes Sep 27 '21 at 19:02
  • But that's only a good idea if 65535 is enough to avoid needing an inner/outer loop at all; Otherwise an inner loop of 128 or 255 iterations with cheap psadbw / paddd or q is good. – Peter Cordes Sep 27 '21 at 19:04
  • @PeterCordes I believe that the cleanup calls and upper limit issue have been properly addressed. Thanks again for your feedback as it has been tremendously helpful. – Kittoes0124 Sep 27 '21 at 20:53
  • It looks like the inner `do{} while (loopIterationNumber < 255);` unconditionally does 64 * 255 bytes, without checking `lengthToExamine`. Outside the loop you could calculate a `max(255, length/64)` or something. (And don't bother decrementing the length and count inside the loop.) – Peter Cordes Sep 27 '21 at 20:59
  • @PeterCordes Damnit, caught that almost immediately after posting and was really hoping to beat you to the punch. You truly are the master, as the calculation outside the loop is obviously more elegant too. – Kittoes0124 Sep 27 '21 at 21:04
  • Also, in the outer loop, you don't need to reduce to scalar, just to one SIMD vector `var sum` would be fine. The hsum of that to a scalar integer can sink out of the outer loop. (Make sure you're using SIMD addition with 64-bit or at least 32-bit element size for all the accumulation of psadbw results. I guess SumAbsoluteDifferences() returns a `vector` or whatever C# calls it, which implies the element type) – Peter Cordes Sep 27 '21 at 21:04
  • Also, consider what happens for an input that's say 31 bytes long (or 63x uint16 since you're doing strings not bytes like the question asked). Or n*64 + 31. That's a lot of scalar iterations. That's the downside to unrolling: unless you also provide a not-unrolled vector loop, you make the worst case (including small cases) spend more time in the slow scalar code. If you want to tune this for short to medium strings, you might provide a loop that does one SSE2 vector per iteration, leaving at most 7 leftover elements. – Peter Cordes Sep 27 '21 at 21:08
  • You can even do a 4 element vector with an 8-byte load. Since SSE2 support is near-universal on x86 that could be running modern C# code, maybe do that and let non-SSE2 CPUs fall back to the not-unrolled loop by making the 4-at-a-time case use SSE2 `movq` loads. – Peter Cordes Sep 27 '21 at 21:10
  • Not sure I fully followed what stops big cases from using that loop. Does this do the n % 64 "leftover" elements at the start? That means if you have an aligned buffer with 1025 elements, all the loads in the main vector loop will be misaligned, instead of just the last few. Not a big slowdown with most AVX2 CPUs, but if your data was already hot in cache it could matter more. Generally good to design functions to be good when you pass an aligned buffer, even if the length is odd. – Peter Cordes Sep 27 '21 at 21:16
  • Your goto to re-run that set of loops is interesting, though; reducing the amount of branching needed to even dispatch to the right version of a loop is important too, especially for small problems. (That's the tradeoff memcpy implementations have to make, for example.) So misaligning the pointer could be worth the cost if this trick makes the fast path simpler for common cases. – Peter Cordes Sep 27 '21 at 21:17
  • @PeterCordes The `SequentialScan` label performs data alignment while also taking care of the "leftover" elements. In an ideal setting `UnalignedCountVector128` will return zero and we'll skip straight over both loops into the Avx2/Sse2 section. If that's not possible then we'll immediately hit one or both loops. The branches `if (Avx2.IsSupported)` and `if (Sse2.IsSupported)` can also effectively be ignored because they'll get elided by JIT compilation. Does that address any concerns with this particular approach? Am still trying to figure out how to use `Max(255...`. – Kittoes0124 Sep 27 '21 at 21:33
  • Oops, meant to write `min(..., 255)` to clamp to 255 as an upper limit on how many inner iterations to do, but do less if you're close to the end of the buffer. I frequently mix up `min` vs. `max` for setting a maximum on some value if I don't stop and think about it. :/ – Peter Cordes Sep 27 '21 at 21:41
  • Re: `UnalignedCountVector128`, ok right, yeah that looks good. Although with modern CPUs having efficient unaligned-load support, it's often *not* worth spending extra instructions to reach an alignment boundary before starting your vector loop. Especially if misaligned is rare. (Which is what makes this advice different from not creating misalignment if the length is odd. And that handling odd length always takes some extra instructions, but doing extra setup work before a loop can be optional.) – Peter Cordes Sep 27 '21 at 21:47
  • Although many buffers end up aligned by 16 but not by 32 (at least with glibc malloc, IDK about C#!), so AVX misalignment isn't rare. – Peter Cordes Sep 27 '21 at 21:48
  • @PeterCordes Testing indicated that it was possible to overflow the outer sum with a large enough stream of occurrences so I added an iteration limit of 16 to the outer loop. Also addressed the 31 and 15 byte problems with a final loop. I -think- that addresses everything. Thanks again for all the guidance. – Kittoes0124 Sep 28 '21 at 05:57
  • If you write it correctly, the outer loop is accumulating into 64-bit counts (in one 256-bit vector) after `psadbw`, using `paddq` (or `paddd`) on each of the 4 accumulators. If you're overflowing, perhaps your C# `var` type isn't 64 or even 32-bit elements, so `.Add` implies `paddw` or `paddb`. Don't have time atm to look over your code. – Peter Cordes Sep 28 '21 at 06:54
  • @PeterCordes No worries. You've been more than helpful enough to get me on the right track. My code was treating them as a series of `ushort` instead of `ulong` and now works as expected with the fix in place. Much obliged sir. – Kittoes0124 Sep 28 '21 at 13:13
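
A minimal sketch of the vpmaddwd reduction described in the comments above, assuming the per-element counts live in a Vector256<ushort> and AVX2 is available. The class and method names are illustrative; note that pmaddwd is a signed multiply, so this variant is only exact for counts up to 32767 per element.

using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class HsumSketch
{
    // Horizontal-sum 16-bit per-element counts into one scalar total.
    public static int HorizontalSum(Vector256<ushort> counts16)
    {
        // vpmaddwd with set1_epi16(1): widening add of adjacent 16-bit pairs into 32-bit lanes.
        var counts32 = Avx2.MultiplyAddAdjacent(counts16.AsInt16(), Vector256.Create((short)1));
        // Standard 32-bit hsum: fold 256 -> 128, then swap halves / pairs and add.
        var sum128 = Sse2.Add(counts32.GetLower(), counts32.GetUpper());
        sum128 = Sse2.Add(sum128, Sse2.Shuffle(sum128, 0x4E)); // swap 64-bit halves
        sum128 = Sse2.Add(sum128, Sse2.Shuffle(sum128, 0xB1)); // swap 32-bit pairs
        return sum128.ToScalar();
    }
}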

Here is a fast SSE2 implementation in C:

#include <stddef.h>
#include <smmintrin.h>  // SSE2 intrinsics, plus SSE4.1 for _mm_extract_epi64

size_t memcount_sse2(const void *s, int c, size_t n) {
    __m128i cv = _mm_set1_epi8(c), sum = _mm_setzero_si128(), acr0, acr1, acr2, acr3;
    const char *p, *pe;

    for (p = s; p != (char *)s + (n - (n % (252 * 16)));) {
        for (acr0 = acr1 = acr2 = acr3 = _mm_setzero_si128(), pe = p + 252 * 16; p != pe; p += 64) {
            acr0 = _mm_add_epi8(acr0, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)p)));
            acr1 = _mm_add_epi8(acr1, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p + 16))));
            acr2 = _mm_add_epi8(acr2, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p + 32))));
            acr3 = _mm_add_epi8(acr3, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p + 48))));
            __builtin_prefetch(p + 1024);
        }
        // Compare results are 0 / -1 per byte, so negate the accumulators before psadbw.
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr0), _mm_setzero_si128()));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr1), _mm_setzero_si128()));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr2), _mm_setzero_si128()));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr3), _mm_setzero_si128()));
    }

    // _mm_extract_epi64 requires SSE4.1; rewrite this part for plain SSE2.
    size_t count = _mm_extract_epi64(sum, 0) + _mm_extract_epi64(sum, 1);

    // Scalar cleanup. Could be optimized.
    while (p != (char *)s + n) count += *p++ == c;
    return count;
}

See https://gist.github.com/powturbo for an AVX2 implementation.

powturbo
  • With some compilers, `_mm_extract_epi64(sum, 1)` will only compile with SSE4.1. You could use `_mm_sub_epi8` inside the inner loop to avoid needing to negate the accumulators before `psadbw`. `acr0 -= -1` is the same as `acr0 += 1`. – Peter Cordes Apr 09 '18 at 23:55
  • How much speedup does that prefetch give? On IvyBridge and later, with hardware next-page prefetch, it shouldn't make much difference. – Peter Cordes Apr 09 '18 at 23:56
  • Also, you could do much better for small uneven-size buffers with a cleanup loop that went 1 vector at a time, then maybe 1 `movq`, instead of up to 63 one-byte-at-a-time iterations. Or maybe use a load that goes right up to the *end* of the buffer and mask off the overlapping bytes that you'd double-count. (e.g. load a mask from a sliding window of `...,0,0,0,-1,-1,-1,-1,...`, like this https://stackoverflow.com/questions/34306933/vectorizing-with-unaligned-buffers-using-vmaskmovps-generating-a-mask-from-a-m); a sketch of that overlap trick follows this comment thread. – Peter Cordes Apr 10 '18 at 00:00
  • @PeterCordes thanks for your suggestions. Prefetch speed up is ~10% on i2600k and large buffers. – powturbo Apr 10 '18 at 08:46
  • Interesting. If I get around to it, I'll test on Skylake. (Probably much lower speedup because of next-page prefetching.) I don't have an IvB system, but IvB apparently has some kind of major throughput bottleneck for SW prefetch instructions. – Peter Cordes Apr 10 '18 at 09:04
  • (Update on this old comment thread: "next page" hardware prefetching on Intel is just TLB prefetch, not data. So SW prefetch can sometimes help across page boundaries, if it doesn't slow the loop down getting those extra instructions through the pipeline, reducing the number of iterations the ROB can keep in flight and overlap with earlier / later code.) – Peter Cordes Sep 27 '21 at 22:07
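
A sketch of that overlapping-tail trick in C#, using pmovmskb + popcnt to drop the double-counted lanes rather than a masked load. It assumes the span is at least 16 bytes long and that alreadyCounted (how many leading bytes the main loop handled, a multiple of 16) is passed in; the names are illustrative.

using System;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class TailSketch
{
    // Finish a counting loop with one 16-byte load that ends exactly at the
    // end of the buffer, discarding lanes the main loop already counted.
    public static int CountTail(ReadOnlySpan<byte> span, byte value, int alreadyCounted)
    {
        var last16 = MemoryMarshal.Read<Vector128<byte>>(span.Slice(span.Length - 16, 16));
        int mask = Sse2.MoveMask(Sse2.CompareEqual(last16, Vector128.Create(value)));
        int overlap = alreadyCounted - (span.Length - 16); // lanes that would be counted twice
        return BitOperations.PopCount((uint)(mask >> overlap));
    }
}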

Thanks to Marc Gravell for his tip, the following works:

var areEqual = Vector.Equals(vBytes, mask);
var negation = Vector.Negate(areEqual);
var count = Vector.Dot(negation, Vector<byte>.One);

Marc has a blog post with more info on the subject.
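
Put together with the question's input, the whole thing looks like the sketch below. It assumes Vector<byte>.Count is 16 on the target hardware; the Vector<byte> constructor throws if the array is shorter than Vector<byte>.Count, so pad the input accordingly on wider machines.

using System;
using System.Numerics;

class Program
{
    static void Main()
    {
        // The question's input: 55 appears five times.
        var vBytes = new Vector<byte>(new byte[] {
            72, 101, 55, 8, 108, 111, 55, 87, 111, 114, 108, 55, 100, 55, 55, 20 });
        var mask = new Vector<byte>(55);

        var areEqual = Vector.Equals(vBytes, mask);         // 255 where equal, 0 elsewhere
        var negation = Vector.Negate(areEqual);             // 1 where equal, 0 elsewhere
        var count = Vector.Dot(negation, Vector<byte>.One); // horizontal sum of the 1s

        Console.WriteLine(count); // 5
    }
}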

MaYaN
  • nice; easier to read as `var count = Vector.Dot(-Vector.Equals(vBytes, mask), Vector.One);`, but: like it; note: you need to be *really* careful how you "load" the vectors for SIMD; if you aren't careful, you can lose *all* the benefit due to load overhead. `Span` is a great way to load them - raw arrays: not usually so much – Marc Gravell Mar 29 '18 at 10:14
  • Agreed, this was a contrived example to get the core of it working, in production it will be further optimized. Thanks for the light-bulb moment though! – MaYaN Mar 29 '18 at 10:16