I'm new to using intrinsics. I want to write a function that takes a vector of 4 doubles and computes `a > 1e-5 ? std::sqrt(a) : 0.0` for each element.
My first instinct was to write it as follows:
```cpp
#include <immintrin.h>

__m256d f(__m256d a)
{
    __m256d is_valid = a > _mm256_set1_pd(1e-5);
    __m256d sqrt_val = _mm256_sqrt_pd(a);
    return is_valid * sqrt_val;
}
```
which, according to gcc.godbolt.com, compiles to the following:

```asm
f(double __vector(4)):
        vsqrtpd ymm1, ymm0
        vcmpgtpd        ymm0, ymm0, YMMWORD PTR .LC0[rip]
        vmulpd  ymm0, ymm1, ymm0
        ret
.LC0:
        .long   2296604913
        .long   1055193269
        .long   2296604913
        .long   1055193269
        .long   2296604913
        .long   1055193269
        .long   2296604913
        .long   1055193269
```
But I'm worried about what will happen if `sqrt_val` contains a NaN (for example, when an element of `a` is negative). I don't think `0.0 * NaN` will work, since it evaluates to NaN rather than 0.0. What is the best practice here?
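A vector comparison does not produce 0.0/1.0 values: each lane of the result is a bit mask that is all-ones where the comparison holds and all-zeros where it does not. All-ones bits reinterpreted as a double are a NaN, so multiplying by the mask cannot select values. A bitwise AND can, and a lane whose mask is all-zeros becomes +0.0 no matter what bits (NaN included) the square root produced there. The sketch below assumes GCC or Clang (for `__attribute__((target("avx")))`) and uses the explicit `_mm256_cmp_pd` intrinsic instead of the GCC vector-extension `>` operator; `sqrt_or_zero`, `sqrt_or_zero_array`, and `sqrt_or_zero_scalar` are hypothetical names.

```cpp
#include <cmath>
#include <immintrin.h>

// Hypothetical scalar reference for the intended per-element behaviour.
double sqrt_or_zero_scalar(double a)
{
    return a > 1e-5 ? std::sqrt(a) : 0.0;
}

// Assumes GCC or Clang: target("avx") enables AVX for this function only,
// so the file builds without -mavx (the CPU must still support AVX).
__attribute__((target("avx")))
__m256d sqrt_or_zero(__m256d a)
{
    // All-ones bits in lanes where a > 1e-5, all-zero bits elsewhere.
    // _CMP_GT_OQ is false when a lane of `a` is NaN.
    __m256d mask = _mm256_cmp_pd(a, _mm256_set1_pd(1e-5), _CMP_GT_OQ);
    // Negative lanes produce NaN here, but their mask bits are zero.
    __m256d root = _mm256_sqrt_pd(a);
    // Keep the sqrt bits where the mask is set; every other lane becomes +0.0.
    return _mm256_and_pd(mask, root);
}

// Hypothetical wrapper so the vector function can be exercised with plain arrays.
__attribute__((target("avx")))
void sqrt_or_zero_array(const double in[4], double out[4])
{
    _mm256_storeu_pd(out, sqrt_or_zero(_mm256_loadu_pd(in)));
}
```

With this version a negative input still yields NaN from `_mm256_sqrt_pd`, but that lane's mask is zero, so the returned value is +0.0; a NaN input compares false and also yields 0.0. The scalar function spells out the intended semantics and is handy as a reference when testing the vector version.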
Edit
After reading the comments from @ChrisCooper (and @njuffa), I was pointed to another Stack Overflow answer, so I now test the result for self-equality (which is false only in NaN lanes) and then AND that mask with the result:
```cpp
#include <immintrin.h>

__m256d f(__m256d a)
{
    __m256d is_valid = a > _mm256_set1_pd(1e-5);
    __m256d sqrt_val = _mm256_sqrt_pd(a);
    __m256d result = is_valid * sqrt_val;
    __m256d cmpeq = result == result;
    return _mm256_and_pd(cmpeq, result);
}
```
which compiles to the following:

```asm
f(double __vector(4)):
        vsqrtpd ymm1, ymm0
        vcmpgtpd        ymm0, ymm0, YMMWORD PTR .LC0[rip]
        vmulpd  ymm0, ymm1, ymm0
        vcmpeqpd        ymm1, ymm0, ymm0
        vandpd  ymm0, ymm1, ymm0
        ret
.LC0:
        .long   2296604913
        .long   1055193269
        .long   2296604913
        .long   1055193269
        .long   2296604913
        .long   1055193269
        .long   2296604913
        .long   1055193269
```
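When the fallback value is not 0.0, an AND is no longer enough. A minimal sketch, again assuming GCC or Clang and using hypothetical names (`sqrt_or_fallback`, `zero_nans`, and the array wrappers), uses `_mm256_blendv_pd`, which takes each lane from its second operand where the corresponding mask lane has its sign bit set and from its first operand otherwise. It also writes the self-equality NaN test from the edit with `_mm256_cmp_pd` and `_CMP_EQ_OQ`, because `==` on `__m256d` is a GCC/Clang vector extension rather than part of the Intel intrinsics API.

```cpp
#include <immintrin.h>

// Hypothetical general form: a > threshold ? sqrt(a) : fallback, per lane.
// Assumes GCC or Clang; target("avx") avoids needing -mavx for this file.
__attribute__((target("avx")))
__m256d sqrt_or_fallback(__m256d a, double threshold, double fallback)
{
    __m256d mask = _mm256_cmp_pd(a, _mm256_set1_pd(threshold), _CMP_GT_OQ);
    __m256d root = _mm256_sqrt_pd(a);
    // Lanes whose mask sign bit is set come from `root`, the rest from the
    // fallback vector, so NaNs in unselected sqrt lanes never reach the result.
    return _mm256_blendv_pd(_mm256_set1_pd(fallback), root, mask);
}

// The edit's self-equality test with explicit intrinsics: x == x is false
// only in NaN lanes, so the AND turns exactly those lanes into +0.0.
__attribute__((target("avx")))
__m256d zero_nans(__m256d x)
{
    __m256d not_nan = _mm256_cmp_pd(x, x, _CMP_EQ_OQ);
    return _mm256_and_pd(not_nan, x);
}

// Hypothetical wrappers so both functions can be exercised with plain arrays.
__attribute__((target("avx")))
void sqrt_or_fallback_array(const double in[4], double threshold,
                            double fallback, double out[4])
{
    _mm256_storeu_pd(out, sqrt_or_fallback(_mm256_loadu_pd(in), threshold, fallback));
}

__attribute__((target("avx")))
void zero_nans_array(const double in[4], double out[4])
{
    _mm256_storeu_pd(out, zero_nans(_mm256_loadu_pd(in)));
}
```

Selecting with a mask before any NaN can reach the output (AND for a 0.0 fallback, blend otherwise) removes the need for a separate NaN-scrubbing pass; `zero_nans` is included only to show the intrinsic form of the self-equality approach.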