9

Suppose we want to quickly find the index of the first nonzero element in an array, along the lines of

fn leading_zeros(arr: &[u32]) -> Option<usize> {
    arr.iter().position(|&x| x != 0)
}

However, rustc compiles this to a one-by-one check, as seen here. One can speed this up a bit by checking the words four at a time using the `u128` type, as follows. This gives a speedup of roughly 3x on my machine.

fn leading_zeros_wide(arr: &[u32]) -> Option<usize> {
    // Split into an unaligned prefix, a u128-aligned middle (four u32 words
    // per element), and an unaligned suffix.
    let (beg, mid, _) = unsafe { arr.align_to::<u128>() };

    beg.iter().position(|&x| x != 0).or_else(|| {
        // Skip zeros four words at a time through the aligned middle...
        let left = beg.len() + 4 * mid.iter().position(|&x| x != 0).unwrap_or(mid.len());
        // ...then pinpoint the nonzero word with a scalar scan from there.
        arr[left..].iter().position(|&x| x != 0).map(|p| p + left)
    })
}

Is there a way to make this even faster?


Here is the benchmark I used to determine the 3x speedup:

#![feature(test)]
extern crate test;

fn v() -> Box<[u32]> {
    std::iter::repeat(0).take(1000).collect()
}

// Assume `leading_zeros` and `leading_zeros_wide` are defined here.

#[bench]
fn bench_leading_zeros(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros(&v[3..]))
}

#[bench]
fn bench_leading_zeros_wide(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros_wide(&v[3..]))
}
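
(The benchmark needs a nightly toolchain for `#![feature(test)]`; run it with `cargo +nightly bench`. The `&v[3..]` slices presumably misalign the data on purpose, so that `align_to` has a nonempty scalar prefix to work through.)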
MERTON
  • @Stargateur The line starting with `let left` skips zeros four at a time, by interpreting four adjacent words as a single `u128`. If we cannot skip zeros this way, we fall back to scanning one by one. – MERTON Dec 27 '21 at 18:18
  • 1
    @JohnKugelman I didn't use the `end` parameter because the slice `arr[left..]` contains that part – MERTON Dec 27 '21 at 18:32
  • 1
    @JohnKugelman mm, with your explanation I see why `end` is ignored. I think the code deserves some annotation or better variable naming; it looks OK to me now. That said, since the question claims this is faster, I think having benchmark code to test it in the question would be a plus if not a requirement. – Stargateur Dec 27 '21 at 18:34
  • 7
    I think https://docs.rs/memx/latest/memx/fn.memnechr.html should be faster and more reliable – Stargateur Dec 27 '21 at 19:25
  • As already said here, use `memchr()`. Other than that, in similar cases, use SIMD. – Chayim Friedman Dec 27 '21 at 20:08
  • 2
    Thanks all! Sadly, `memx` crate appears to have a bug at the moment for `memnechr` (at least for 0.1.18) – MERTON Dec 27 '21 at 20:41
  • 2
  • I see that your optimized version is still not SIMD even when SIMD features are specified as compiler options: https://rust.godbolt.org/z/8scnKToq8 which means it **can** be optimized further. Apparently there is a way to use CPU intrinsics directly: [x86](https://rust-lang.github.io/stdarch/x86_64/core_arch/x86_64/fn._mm_cmpeq_epi8.html), [arm](https://doc.rust-lang.org/beta/core/arch/arm/fn.vceqq_u8.html). Sorry, I will not provide this solution; I don't know Rust (I found this question via the `[simd]` tag) – Alex Guteniev Dec 29 '21 at 20:00
  • 1
    I don't know how to use SIMD intrinsics in Rust, but the asm instructions you want it to emit on x86 are to search for a vector containing a non-zero element, then [Is there an efficient way to get the first non-zero element in an SIMD register using SIMD intrinsics?](https://stackoverflow.com/q/40032906) to find position in that vector. Like my AVX2 C intrinsics answer on [Efficiently find least significant set bit in a large array?](https://stackoverflow.com/q/67605508) (which does a bit-scan on the non-zero element once it finds it, to find the bit-position) – Peter Cordes Dec 29 '21 at 21:19
  • @PeterCordes I'm not worried about the 'last mile', that is finding the `u32` inside a chunk. I just couldn't get `rustc` to vectorize the middle part, which is going to make the most impact – MERTON Dec 29 '21 at 22:13
  • 1
    If you need to manually vectorize anyway, you should definitely do it in a way that gets the element index efficiently. Some ways of looking for a non-zero vector on x86 involve `pcmpeqd` / `movmskps` anyway, so you already have the compare-result bitmap in an integer register just waiting for a bit-scan. – Peter Cordes Dec 29 '21 at 22:16
  • 1
    You probably *do* need to manually vectorize; LLVM and GCC's optimizers don't know how to auto-vectorize loops whose trip-count isn't known ahead of the first iteration. (i.e. search loops or other early-out conditions defeat them.) ICC can auto-vectorize such loops. You could maybe get something from portable code by unconditionally doing 4 `u32` elements in an inner loop (a sketch of this idea follows these comments), but it's probably hard to get rustc to spit out a simple pcmpeqd / pmovmskb, rather than some silly horizontal reduction. – Peter Cordes Dec 29 '21 at 22:19
  • As you want to run this on aarch64: please clarify whether you want an aarch64-specific solution, or a generic solution that does not use architecture-specific intrinsics – Alex Guteniev Jan 01 '22 at 11:25
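
A portable sketch of the "unconditionally check a few elements per iteration" idea from the comments above (illustration only: `first_nonzero_or4` is a hypothetical name, and rustc is not guaranteed to turn the ORs into SIMD):

fn first_nonzero_or4(arr: &[u32]) -> Option<usize> {
    let n = arr.len() / 4 * 4;
    let mut i = 0;
    while i < n {
        // OR four adjacent words: the result is nonzero iff any of them is.
        if (arr[i] | arr[i + 1] | arr[i + 2] | arr[i + 3]) != 0 {
            // Pinpoint the word inside the hit group with a scalar scan.
            return arr[i..i + 4].iter().position(|&x| x != 0).map(|p| p + i);
        }
        i += 4;
    }
    // Scalar scan over the remaining len % 4 elements.
    arr[n..].iter().position(|&x| x != 0).map(|p| p + n)
}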

2 Answers

4

64 bit: https://rust.godbolt.org/z/rsxh8P8Er

32 bit: https://rust.godbolt.org/z/3P3ejsnh1

I have only a little experience with Rust and assembly, but I added some tests.

#[cfg(target_feature = "avx2")]
pub mod avx2 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    /// Scalar fallback for slices shorter than one vector.
    fn first_nonzero_tiny(arr: &[u32]) -> Option<usize> {
        arr.iter().position(|&x| x != 0)
    }

    /// Compares the 8 consecutive u32s starting at `offset` against zero and
    /// returns an 8-bit mask with a 1 for every element that *is* zero.
    fn find_u32_zeros_8elems(arr: &[u32], offset: isize) -> i32 {
        unsafe {
            let ymm0 = _mm256_setzero_si256();
            // Unaligned 256-bit load; the caller guarantees 8 elements are in bounds.
            let mut ymm1 = _mm256_loadu_si256(arr.as_ptr().offset(offset) as *const __m256i);
            ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
            let ymm2 = _mm256_castsi256_ps(ymm1);
            _mm256_movemask_ps(ymm2)
        }
    }

    pub fn first_nonzero(arr: &[u32]) -> Option<usize> {
        let size = arr.len();
        if size < 8 {
            return first_nonzero_tiny(arr);
        }

        let mut i: usize = 0;
        let simd_size = size / 8 * 8;
        while i < simd_size {
            let mask: i32 = find_u32_zeros_8elems(arr, i as isize);
            // mask == 255 means all 8 elements were zero; otherwise the first
            // clear bit marks the first nonzero element.
            if mask != 255 {
                return Some((mask.trailing_ones() as usize) + i);
            }
            i += 8;
        }

        // Handle the tail with one final vector that overlaps the previous
        // chunk and ends exactly at the end of the slice.
        let last_chunk = size - 8;
        let mask: i32 = find_u32_zeros_8elems(arr, last_chunk as isize);
        if mask != 255 {
            return Some((mask.trailing_ones() as usize) + last_chunk);
        }

        None
    }
}

use avx2::first_nonzero;

pub fn main() {
    let v = [0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [2];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [1, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [0, 1, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(1));

    let v = [0, 0, 1, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(2));

    let v = [0, 0, 0, 1, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(3));

    let v = [0, 0, 0, 0, 1, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(4));

    let v = [0, 0, 0, 0, 0, 1, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(5));

    let v = [0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(6));

    let v = [0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(7));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(8));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(16));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(15));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 3, 4, 5];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(14));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(17));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(18));
}
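
Note that the module is gated on `#[cfg(target_feature = "avx2")]`, so it only exists when AVX2 is enabled at build time, e.g. with `RUSTFLAGS="-C target-feature=+avx2" cargo run --release` (or the equivalent `-C target-feature=+avx2` flag passed to rustc directly).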
Igor Zhukov
  • Looks good. It should be possible to handle the tail with SIMD, too, for any size>=8, with the last vector being an unaligned load that ends at the end of the array. (`_mm256_loadu_si256`; surprised you use alignment-required `load` in the loop without commenting on that input requirement). (It might be convenient to wrap the SIMD stuff in a helper function, instead of the scalar fallback). – Peter Cordes Dec 31 '21 at 01:15
  • It would be more convenient to handle that if you use `i += 8` and use it as an offset to a pointer-to-i32 before casting the result to `__m256i*`, like C-style `_mm_loadu_si128( (const __m128i*) &arr[i] )` instead of `i + (const __m128i*)arr`. Even for your current code, that would let you use the ending `i` value, instead of `n*8 .. arr.len()`. Although then you'd have to do `i < n-7`; you're solving that vector overshoot problem with `n/=8;` so at this point it's just two equally good styles of doing array indexing for manual SIMD. – Peter Cordes Dec 31 '21 at 01:16
  • Thank you @PeterCordes , I edited my answer and now it uses `_mm256_loadu_si256` instead of `_mm256_load_si256` – Igor Zhukov Dec 31 '21 at 02:01
  • @PeterCordes I refactored my code with your suggestions. – Igor Zhukov Dec 31 '21 at 03:00
  • Nice, that saved an instruction inside the loop, now it's only incrementing `i`, not also updating some other counter for the benefit of code after the loop. A more meaningful name for the helper function might be `find_u32_zeros_8elems` - assigning the result to a variable called `mask` is sufficient reminder that it's a compare bitmask. – Peter Cordes Dec 31 '21 at 03:49
  • rustc's asm isn't really optimal, but fairly reasonable. e.g. could save a uop with a memory-source compare (but only if it avoided an indexed addressing mode for Intel SnB-family CPUs, which would mean more work to calc `i` outside the loop after the pointer increment). And it could avoid duplicating the pcmpeq/movmskps stuff, sharing that between both ways out of the function if it planned regs for it. I was hoping it would use `tzcnt` over `bsf` when BMI was available (`target-feature=+avx2,bmi,bmi2`), but no: https://rust.godbolt.org/z/r5r7z6dG6. – Peter Cordes Dec 31 '21 at 03:50
  • Interesting, `-C target-cpu=haswell` (https://rust.godbolt.org/z/cP4zn5hx1) does get it to use `tzcnt`. But still the useless `movzx` after movmskps, not even going to a different register, defeating mov-elimination. But it does `tzcnt` (or `bsf`) into a different register, causing a false-dependency on RDX, which that path of execution through the function didn't previously write. (BSF always has an output dependency so it can leave dst unmodified on input=0; TZCNT does on SnB-family before Skylake). Anyway, those are rustc / LLVM missed optimizations, nothing you can fix in the src. – Peter Cordes Dec 31 '21 at 03:57
  • I renamed the helper function. And now I know that bmi and bmi2 are not implied when specifying avx2 – Igor Zhukov Dec 31 '21 at 05:41
  • 1
    @PeterCordes, apparently `target-feature=+avx2,bmi,bmi2` doesn't enable `tzcnt`, you have to use `target-feature=+avx2,+bmi` (bmi2 doesn't seem to be needed here) – Alex Guteniev Dec 31 '21 at 09:01
  • @AlexGuteniev: Ah, thanks. I should have checked the manual instead of guessing syntax. (And yes, `tzcnt` is part of BMI1; I threw in BMI2 to see if `shrx` and so on gave rustc anything useful to play with.) – Peter Cordes Dec 31 '21 at 09:04
  • 1
    @IgorZhukov Thank you so much! Will follow up with bench results on my machine (which will take time since I need to port it to aarch64) – MERTON Dec 31 '21 at 13:31
  • @MertSağlam, oh. Porting this to aarch64 means solving this again for aarch64, as manual vectorization implies CPU-architecture-specific intrinsics. We thought you needed x86 because the Godbolt link showed x86. I don't know if it can show aarch64, though – Alex Guteniev Jan 01 '22 at 09:28
  • @AlexGuteniev It can show aarch64 assembly, but I have even less experience with aarch64 assembly :( Example: https://rust.godbolt.org/z/MWvnK1azz – Igor Zhukov Jan 01 '22 at 09:53
  • I found an example on GitHub which uses Rust aarch64 intrinsics: https://github.com/3andne/trojan-oxide/blob/main/src/simd/simd_parse.rs And I managed to compile it on Godbolt: https://rust.godbolt.org/z/WhhqEzrxq – Igor Zhukov Jan 01 '22 at 10:21
  • @MertSağlam I did a port: https://rust.godbolt.org/z/K9GW6nxds But I didn't test it. It could be completely broken, but maybe it saves some time for you... (a rough sketch along the same lines follows these comments) – Igor Zhukov Jan 01 '22 at 14:06
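
For reference, a minimal sketch of what such an aarch64 port could look like: an untested illustration in the same scan-a-vector-then-rescan spirit, not the Godbolt port linked above. It assumes the NEON intrinsics `vld1q_u32` and `vmaxvq_u32`, which were nightly-only in Rust at the time:

#[cfg(target_arch = "aarch64")]
pub fn first_nonzero_neon(arr: &[u32]) -> Option<usize> {
    use std::arch::aarch64::*;

    let n = arr.len() / 4 * 4;
    let mut i = 0;
    while i < n {
        // SAFETY: i + 4 <= arr.len(); vld1q_u32 has no alignment requirement.
        // The horizontal max of a 4-lane chunk is nonzero iff any lane is nonzero.
        let any = unsafe { vmaxvq_u32(vld1q_u32(arr.as_ptr().add(i))) };
        if any != 0 {
            // Pinpoint the word inside the hit chunk with a scalar scan.
            return arr[i..i + 4].iter().position(|&x| x != 0).map(|p| p + i);
        }
        i += 4;
    }
    // Scalar scan over the remaining len % 4 elements.
    arr[n..].iter().position(|&x| x != 0).map(|p| p + n)
}
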
1

Here is a solution which is faster than the baseline but probably still leaves a ton on the table.

The following achieves a 7.5x speedup over the baseline `first_nonzero`.

/// Finds the position of the first nonzero element in a given slice which
/// contains a nonzero.
///
/// # Safety
///
/// The caller *has* to ensure that the slice contains a nonzero element in
/// its unaligned prefix or `u128`-aligned middle (as split by
/// `align_to::<u128>`); if the only nonzero sat in the unaligned tail, the
/// unrolled loop below would read past the end of the slice.
unsafe fn first_nonzero_padded(arr: &[u32]) -> usize {
    let (beg, mid, _) = arr.align_to::<u128>();
    beg.iter().position(|&x| x != 0).unwrap_or_else(|| {
        let left = beg.len()
            + 4 * {
                // Manually unrolled by 8: scan u128 chunks (4 words each)
                // until one is nonzero; the loop breaks with a pointer to it.
                let mut p: *const u128 = mid.as_ptr();
                loop {
                    if *p.offset(0) != 0 { break p.offset(0); }
                    if *p.offset(1) != 0 { break p.offset(1); }
                    if *p.offset(2) != 0 { break p.offset(2); }
                    if *p.offset(3) != 0 { break p.offset(3); }
                    if *p.offset(4) != 0 { break p.offset(4); }
                    if *p.offset(5) != 0 { break p.offset(5); }
                    if *p.offset(6) != 0 { break p.offset(6); }
                    if *p.offset(7) != 0 { break p.offset(7); }
                    p = p.offset(8);
                }.offset_from(mid.as_ptr()) as usize
            };
        // Pinpoint the word inside the nonzero chunk with a scalar scan.
        if let Some(p) = arr[left..].iter().position(|&x| x != 0) {
            left + p
        } else {
            // Unreachable: the safety precondition guarantees a nonzero here.
            core::hint::unreachable_unchecked()
        }
    })
}
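
For illustration, a hypothetical call site that upholds the precondition (the nonzero element sits well inside the `u128`-aligned middle, so the unrolled loop never runs off the end):

fn main() {
    let mut data = vec![0u32; 1000];
    data[900] = 7;
    // SAFETY: `data` contains a nonzero element well inside its aligned middle.
    let pos = unsafe { first_nonzero_padded(&data) };
    assert_eq!(pos, 900);
}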
MERTON
  • Is there a way to compile (on Godbolt) the attempted SIMD version from [the first revision](https://stackoverflow.com/revisions/70533785/1) of this answer? with `use core_simd::u64x2;` and so on? https://godbolt.org/z/E6ozdhdYc didn't work for me with rustc nightly. If it was slower, very likely your `mask8x8::from_array([ *p.offset(00) != ZERO,` up to `07` `])` didn't compile to a single SSE4.1 `pcmpeqq` or whatever. IDK if that would spend a lot of scalar work packing 8x 2-bit compare results into a single `mask8x8`, or worse booleanizing those 2-bit results into 1-bit results? – Peter Cordes Dec 30 '21 at 21:33
  • But anyway, describing that as `pcmpeqd` / `tzcnt` is almost certainly bogus, so yeah, no wonder you deleted it from your answer :P And I'm not surprised an early-out on 16-byte chunks is a bit better; you want the inner loop to *not* spend a ton of work preparing for the stuff after the loop that sorts out where the non-zero element was. e.g. if you expect long runs of zeros, you can even OR together multiple vectors, then re-check them individually later (sketched after these comments). (Working in cache-line sized chunks is good, especially if your data is aligned by 64) – Peter Cordes Dec 30 '21 at 21:34
  • Your current code is doing scalar `OR` of two 64-bit chunks, branching on FLAGS set by that. https://godbolt.org/z/6fMEvveMb – Peter Cordes Dec 30 '21 at 21:40
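
A minimal portable sketch of that OR-together-then-recheck idea (illustration only; `first_nonzero_chunked` is a hypothetical name): fold each 16-element chunk with OR, and pay for the per-element scan only once a chunk is known to contain a nonzero.

fn first_nonzero_chunked(arr: &[u32]) -> Option<usize> {
    // A chunk ORs to nonzero iff it contains a nonzero element.
    arr.chunks(16)
        .position(|c| c.iter().fold(0, |acc, &x| acc | x) != 0)
        .and_then(|ci| {
            // Re-scan from the start of the hit chunk to find the exact index.
            arr[ci * 16..].iter().position(|&x| x != 0).map(|p| ci * 16 + p)
        })
}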