This is a special case of "How to count character occurrences using SIMD" with c=0, where c is the char (byte) value to count matches for. See that Q&A for a well-optimized, manually vectorized AVX2 implementation of char_count(char const* vector, size_t size, char c) with a much tighter inner loop than the one here: it avoids reducing each vector of 0/-1 matches to a scalar separately.
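To illustrate what "avoiding reducing each vector to scalar separately" can look like, here is a minimal SSE2 sketch (not the AVX2 code from the linked Q&A). The function name count_zero_bytes_acc is hypothetical. It keeps sixteen byte-wide counters in a register and only does a horizontal sum once per batch of at most 255 blocks, since a byte lane would overflow after 255 matches.

```c
#include <stddef.h>
#include <emmintrin.h> // SSE2

// Hypothetical helper (an illustration, not the linked Q&A's code):
// counts zero bytes while keeping per-byte counters in a vector.
size_t count_zero_bytes_acc(const char *values, size_t length) {
    const __m128i zero = _mm_setzero_si128();
    size_t total = 0;
    size_t i = 0;

    while (length - i >= 16) {
        // Each byte lane can absorb at most 255 matches before overflowing.
        size_t blocks = (length - i) / 16;
        if (blocks > 255) blocks = 255;

        __m128i acc = zero;
        for (size_t b = 0; b < blocks; b++, i += 16) {
            __m128i v = _mm_loadu_si128((const __m128i *)(values + i));
            // cmpeq yields 0 or -1 per byte; subtracting -1 adds 1.
            acc = _mm_sub_epi8(acc, _mm_cmpeq_epi8(v, zero));
        }

        // One horizontal sum per batch: psadbw leaves two 64-bit partial sums.
        __m128i sums = _mm_sad_epu8(acc, zero);
        total += (size_t)_mm_cvtsi128_si32(sums)
               + (size_t)_mm_cvtsi128_si32(_mm_srli_si128(sums, 8));
    }

    // Scalar tail for the last (length % 16) bytes.
    for (; i < length; i++) total += (values[i] == 0);
    return total;
}
```

The inner loop is just a load, a compare and a subtract, with no branch and no movemask, so its speed should not depend on how the zeros are distributed.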
This is O(n), so the best you can do is reduce the constant factor. One quick fix is to remove the branch. With randomly distributed zeros, this runs as fast as my SSE version below, probably because GCC auto-vectorizes the loop. However, for long runs of zeros, or for a random density of zeros below 1%, the SSE version below is still faster.
int countZeroBytes_fix(char* values, int length) {
    int zeroCount = 0;
    for(int i=0; i<length; i++) {
        zeroCount += values[i] == 0; // comparison yields 0 or 1, no branch
    }
    return zeroCount;
}
I originally thought that the density of zeros would matter. That turns out not to be the case, at least with SSE. Using SSE is a lot faster independent of the density.
Edit: actually, it does depend on the density; the density of zeros just has to be smaller than I expected. With 1/64 zeros (about 1.5%) there is on average one zero in every four SSE registers, so the branch on the mask in the SSE version is not predicted very well. With 1/1024 zeros (0.1%), however, it is faster (see the table of times).
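As a rough check on the density argument, here is a small sketch that computes how often a 16-byte register contains at least one zero. It assumes each byte is zero independently with probability p, which is an idealization of the rand()-filled test data; the function name prob_block_has_zero is hypothetical.

```c
// Probability that a block of `width` bytes contains at least one zero,
// assuming each byte is zero independently with probability p
// (an assumption, not something measured here).
double prob_block_has_zero(double p, int width) {
    double none = 1.0; // probability that no byte in the block is zero
    for (int i = 0; i < width; i++) none *= (1.0 - p);
    return 1.0 - none;
}
```

Under that assumption, prob_block_has_zero(1.0/64, 16) is about 0.22, so the branch on the mask is taken roughly one time in five in an unpredictable pattern, while prob_block_has_zero(1.0/1024, 16) is about 0.016, so the branch is almost never taken.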
SIMD is even faster if the data has long runs of zeros.
You can pack 16 bytes into an SSE register and compare all 16 bytes with zero at once using _mm_cmpeq_epi8. Then use _mm_movemask_epi8 on the result to take a fast path: when zeros are sparse the mask is zero most of the time and the block can be skipped, and inside a run of zeros the mask is 0xffff, so you can add 16 directly. You could get a speedup of up to 16x in this case (with the first half ones and the second half zeros I got over a 12x speedup).
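The compare-and-movemask step for a single block can be sketched as follows. The helper name zeros_in_block16 is hypothetical, and __builtin_popcount is a GCC/Clang builtin, so this assumes one of those compilers. It counts the set bits of the mask instead of using the horizontal sum from the full listing, which is a swapped-in alternative rather than what the code below does.

```c
#include <emmintrin.h> // SSE2

// Hypothetical helper: number of zero bytes in one 16-byte block.
int zeros_in_block16(const char *p) {
    __m128i v   = _mm_loadu_si128((const __m128i *)p);      // unaligned load
    __m128i cmp = _mm_cmpeq_epi8(v, _mm_setzero_si128());   // 0xFF where byte == 0
    int mask    = _mm_movemask_epi8(cmp);                   // one bit per byte
    return __builtin_popcount((unsigned)mask);              // GCC/Clang builtin
}
```

A mask of 0 means the block has no zeros, and a mask of 0xffff means all 16 bytes are zero; those are the two cheap cases the full version branches on.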
Here is a table of times in seconds for 2^16 bytes (repeated 10000 times).
                     1.5% zeros   50% zeros   0.1% zeros   1st half 1, 2nd half 0
countZeroBytes       0.8s         0.8s        0.8s         0.95s
countZeroBytes_fix   0.16s        0.16s       0.16s        0.16s
countZeroBytes_SSE   0.2s         0.15s       0.10s        0.07s
You can see the results for the case where the second half is all zeros at http://coliru.stacked-crooked.com/a/67a169ddb03d907a
#include <stdio.h>
#include <stdlib.h>
#include <emmintrin.h> // SSE2
#include <omp.h>
int countZeroBytes(char* values, int length) {
    int zeroCount = 0;
    for(int i=0; i<length; i++) {
        if (!values[i])
            ++zeroCount;
    }
    return zeroCount;
}
int countZeroBytes_SSE(char* values, int length) {
    int zeroCount = 0;
    __m128i zero16 = _mm_set1_epi8(0);
    __m128i and16 = _mm_set1_epi8(1);
    int i = 0;
    // Process full 16-byte blocks only, so the load never reads past the end.
    for(; i<=length-16; i+=16) {
        __m128i values16 = _mm_loadu_si128((__m128i*)&values[i]);
        __m128i cmp = _mm_cmpeq_epi8(values16, zero16);
        int mask = _mm_movemask_epi8(cmp);
        if(mask) {
            if(mask == 0xffff) zeroCount += 16;
            else {
                cmp = _mm_and_si128(and16, cmp); //change -1 values to 1
                //horizontal sum of 16 bytes
                __m128i sum1 = _mm_sad_epu8(cmp,zero16);
                __m128i sum2 = _mm_shuffle_epi32(sum1,2);
                __m128i sum3 = _mm_add_epi16(sum1,sum2);
                zeroCount += _mm_cvtsi128_si32(sum3);
            }
        }
    }
    // Scalar tail for lengths that are not a multiple of 16.
    for(; i<length; i++) zeroCount += values[i] == 0;
    return zeroCount;
}
int main() {
    const int n = 1<<16;
    const int repeat = 10000;
    char *values = (char*)_mm_malloc(n, 16);
    for(int i=0; i<n; i++) values[i] = rand()%64; //1.5% zeros
    //alternative input: first half ones, second half zeros
    //for(int i=0; i<n/2; i++) values[i] = 1;
    //for(int i=n/2; i<n; i++) values[i] = 0;
    int zeroCount = 0;
    double dtime;
    //time the scalar version
    dtime = omp_get_wtime();
    for(int i=0; i<repeat; i++) zeroCount = countZeroBytes(values,n);
    dtime = omp_get_wtime() - dtime;
    printf("zeroCount %d, time %f\n", zeroCount, dtime);
    //time the SSE version
    dtime = omp_get_wtime();
    for(int i=0; i<repeat; i++) zeroCount = countZeroBytes_SSE(values,n);
    dtime = omp_get_wtime() - dtime;
    printf("zeroCount %d, time %f\n", zeroCount, dtime);
    _mm_free(values);
    return 0;
}