
I have been trying to find divisors of potential factorial primes (numbers of the form n! ± 1). Because I recently bought a Skylake-X workstation, I thought I could get some speedup using AVX-512 instructions.

The algorithm is simple. The main step is to repeatedly take the modulus with respect to the same divisor while looping over a large range of n values. Here is a naive approach written in C (P is a table of primes):

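#include <stdint.h>
#include <stdlib.h>

uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
    uint64_t n, i, residue;
    for (i = 0; i < APP_BUFLEN; i++) {
        residue = 2;                                   // 2! mod P[i]
        for (n = 3; n <= nmax; n++) {
            // residue * n can exceed 64 bits (p < 10^13, n < 10^7), so use a
            // 128-bit intermediate (GCC/Clang extension) to avoid overflow
            residue = (uint64_t)(((unsigned __int128)residue * n) % P[i]);
            // Let's check if we found a factor
            if (nmin <= n) {
                if (residue == 1) {
                    report_factor(n, -1, P[i]);        // P[i] divides n! - 1
                }
                if (residue == P[i] - 1) {
                    report_factor(n, 1, P[i]);         // P[i] divides n! + 1
                }
            }
        }
    }
    return EXIT_SUCCESS;
}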

The idea is to check a large range of n, e.g. 1,000,000 to 10,000,000, against the same set of divisors, so we take the modulus with respect to the same divisor several million times. Using DIV is very slow, so there are several possible approaches depending on the range of the calculation. In my case n is most likely less than 10^7 and a potential divisor p is less than 10,000 G (< 10^13), so both numbers fit in 64 bits and even in 53 bits (the precision of a double). However, the product of the maximum residue (p - 1) and n can be larger than 64 bits. So I thought that the simplest version of the Montgomery method would not work, because we are taking the modulus of a number that is larger than 64 bits.
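The width argument above can be checked with a few lines of C. This is a minimal sketch, assuming GCC or Clang on a 64-bit target (`unsigned __int128` is a compiler extension, not standard C); the helper names `mulmod_ref` and `product_overflows_u64` are hypothetical. An exact 128-bit `mulmod_ref` is also handy as a reference when testing faster variants.

```c
#include <assert.h>
#include <stdint.h>

/* Exact (residue * n) mod p; the 128-bit intermediate cannot overflow
   for 64-bit inputs. unsigned __int128 is a GCC/Clang extension. */
static uint64_t mulmod_ref(uint64_t residue, uint64_t n, uint64_t p)
{
    assert(p != 0);
    return (uint64_t)(((unsigned __int128)residue * n) % p);
}

/* Returns 1 if the largest product the loop can form, (p - 1) * n,
   does not fit in 64 bits. */
static int product_overflows_u64(uint64_t p, uint64_t n)
{
    unsigned __int128 prod = (unsigned __int128)(p - 1) * n;
    return prod > UINT64_MAX;
}
```

For p = 10^13 and n = 10^7, (p - 1) * n is about 10^20 while UINT64_MAX is about 1.8 * 10^19, so `product_overflows_u64` returns 1, even though both operands are below 2^53 (about 9.0 * 10^15).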

I found some old code for PowerPC where FMA was used to get an exact product of up to 106 bits (I believe) using doubles, so I converted this approach to AVX-512 (Intel intrinsics). Here is a simple version of the FMA method. It is based on the work of Dekker (1971); "Dekker product" and the FMA version of "TwoProduct" are useful search terms for finding the rationale behind it. This approach has also been discussed on this forum (e.g. here).

#include <math.h>
#include <stdint.h>
#include <stdlib.h>

uint64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
    uint64_t n, i;
    double prime_double, prime_double_reciprocal, quotient, residue;
    double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;

    for (i = 0; i < APP_BUFLEN; i++) {
        residue = 2.0;
        prime_double = (double)P[i];
        prime_double_reciprocal = 1.0 / prime_double;
        n_double = 3.0;
        for (n = 3; n <= nmax; n++) {
            // approximate quotient, rounded to the nearest integer with the magic constant
            nr = n_double * residue;
            quotient = fma(nr, prime_double_reciprocal, rounding_constant);
            quotient -= rounding_constant;
            // exact product: prime * quotient == high + low (Dekker / TwoProduct)
            prime_times_quotient_high = prime_double * quotient;
            prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
            // residue * n - prime * quotient, computed exactly
            residue = fma(residue, n_double, -prime_times_quotient_high) - prime_times_quotient_low;

            if (residue < 0.0) residue += prime_double;
            n_double += 1.0;

            // Let's check if we found a factor
            if (nmin <= n) {
                if (residue == 1.0) {
                    report_factor(n, -1, P[i]);
                }
                if (residue == prime_double - 1.0) {
                    report_factor(n, 1, P[i]);
                }
            }
        }
    }
    return EXIT_SUCCESS;
}

Here I have used the magic constant

static const double rounding_constant = 6755399441055744.0; 

which is 2^52 + 2^51, the classic magic number for rounding doubles to integers.
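The magic-constant trick can be seen in isolation. Adding 2^52 + 2^51 to any |x| ≤ 2^51 puts the sum in [2^52, 2^53], where consecutive doubles are exactly 1 apart, so the addition itself rounds x to an integer (to nearest, ties to even, under the default rounding mode); subtracting the constant again is exact. Using 2^52 + 2^51 rather than 2^52 keeps negative inputs in that range too. A minimal sketch, assuming x86-64 SSE2 arithmetic (no x87 extended precision) and no `-ffast-math`, which could fold the two operations away; `round_via_magic` is a hypothetical name:

```c
#include <assert.h>
#include <math.h>

/* 2^52 + 2^51: for |x| <= 2^51, x + magic lies where doubles are spaced
   exactly 1.0 apart, so the addition rounds x to the nearest integer. */
static const double magic = 6755399441055744.0;

static double round_via_magic(double x)
{
    assert(fabs(x) <= 2251799813685248.0);   /* 2^51 */
    /* Do not build with -ffast-math: it may simplify this to x. */
    return (x + magic) - magic;
}
```

In `factorial_FMA` the addition is folded into the `fma` that computes the quotient estimate, so the rounding costs only the extra subtraction.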

I converted this to AVX-512 (32 potential divisors per loop iteration) and analyzed the result using IACA. It reported "Throughput Bottleneck: Backend" and that backend allocation was stalled due to unavailable allocation resources. I am not very experienced with assembly, so my question is: is there anything I can do to speed this up and resolve this backend bottleneck?

Here is the AVX-512 code (it is also available on GitHub):

uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for factorial numbers: n! +- 1
// nmin is the minimum n we want to report and nmax is the maximum. P is the table of primes.
// we process 32 primes in one loop.
// the naive version of the algorithm is in the function factorial_naive
// and a simple version of the FMA-based approach in the function factorial_simpleFMA

const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};


uint64_t n;

__m512d zero, rounding_const, one, n_double;

__m512i prime1, prime2, prime3, prime4;

__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;

__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;

// load data and initialize variables for the loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();

// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);

// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd

// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);

// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);

// residue init
residue1 =  _mm512_set1_pd(2.0);
residue2 =  _mm512_set1_pd(2.0);
residue3 =  _mm512_set1_pd(2.0);
residue4 =  _mm512_set1_pd(2.0);

// double counter init
n_double = _mm512_set1_pd(3.0);

// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000

for (n=3; n<=nmax; n++) // main loop
{

    // timings for instructions:
    // _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
    // _mm512_load_pd = vmovapd : L 1, T 0.5
    // _mm512_set1_pd
    // _mm512_div_pd = vdivpd : L 23, T 16
    // _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0.5

    // _mm512_mul_pd = vmulpd :  L 4, T 0.5
    // _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd :  L 4, T 0.5
    // _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
    // _mm512_sub_pd = vsubpd : L 4, T 0.5
    // _mm512_cmplt_pd_mask = vcmppd : L ?, T 1
    // _mm512_mask_add_pd = vaddpd :  L 4, T 0.5
    // _mm512_cmpeq_pd_mask = vcmppd : L ?, T 1
    // _mm512_kor = korw : L 1, T 1

    // nr = residue *  n
    nr1 = _mm512_mul_pd (residue1, n_double);
    nr2 = _mm512_mul_pd (residue2, n_double);
    nr3 = _mm512_mul_pd (residue3, n_double);
    nr4 = _mm512_mul_pd (residue4, n_double);

    // quotient = nr * 1.0/ prime_double + rounding_constant
    quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
    quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
    quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
    quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);

    // quotient -= rounding_constant; now the quotient is rounded to an integer
    // the quotient should be at most nmax (10,000,000)
    quotient1 = _mm512_sub_pd(quotient1, rounding_const);
    quotient2 = _mm512_sub_pd(quotient2, rounding_const);
    quotient3 = _mm512_sub_pd(quotient3, rounding_const);
    quotient4 = _mm512_sub_pd(quotient4, rounding_const);

    // now we calculate the high and low parts of prime * quotient using the Dekker product (FMA).
    // the quotient is only an approximation, but this product is exact for the given quotient
    prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
    prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
    prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
    prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);


    prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
    prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
    prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
    prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);

    // now we calculate the new residue exactly from the original values (residue * n)
    // by subtracting prime * quotient computed above (the quotient is an approximation)

    residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
    residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
    residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
    residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);

    residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
    residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
    residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
    residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);

    // check whether the residue is < 0
    negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
    negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
    negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
    negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);

    // add the prime back to the residue (masked) where it was < 0
    residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
    residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
    residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
    residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);

    n_double = _mm512_add_pd(n_double,one);

    // if we are below nmin then we continue next iteration
    if (n < nmin) continue;

    // Let's check if we found any factors: residue == 1 means the prime divides n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue == prime - 1 means the prime divides n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
    { // we find a factor very rarely

        double *residual_list1 = (double *) &residue1;
        double *residual_list2 = (double *) &residue2;
        double *residual_list3 = (double *) &residue3;
        double *residual_list4 = (double *) &residue4;

        double *prime_list1 = (double *) &prime_double1;
        double *prime_list2 = (double *) &prime_double2;
        double *prime_list3 = (double *) &prime_double3;
        double *prime_list4 = (double *) &prime_double4;



        for (int i=0; i <8; i++){
            if( residual_list1[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
            }
            if( residual_list2[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
            }
            if( residual_list3[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
            }
            if( residual_list4[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
            }

            if(residual_list1[i] == (prime_list1[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
            }
            if(residual_list2[i] == (prime_list2[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
            }
            if(residual_list3[i] == (prime_list3[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
            }
            if(residual_list4[i] == (prime_list4[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
            }
        }
    }

}

return EXIT_SUCCESS;
}

Asked by Nuutti; edited by Óscar Andreu.

  • Upvote for a detailed and well-asked question. Welcome to Stack Overflow! – fuz Dec 17 '17 at 14:07
  • Just out of curiosity, does this `if(residue == prime_double - 1.0)` work reliably (**==**)? It's not obvious to me just by reading source, that the values will stay integer only and within double mantissa limits, so no low digits will be lost. But it may be, depends on `fma` implementation... still feels fragile enough to me, to be worth of extra source comment, why it should work. – Ped7g Dec 17 '17 at 18:33
  • @Ped7g: `fma` is required to not round the internal temporary product, that's most of the point of it existing. It's expensive to implement on hardware that doesn't provide an FMA operation natively. But on x86 with `-march=haswell` or higher, it can inline to a single instruction. – Peter Cordes Dec 17 '17 at 18:56
  • @Nuutti: A back-end bottleneck on FMA throughput is good, it means you're saturating the FMA throughput of the machine instead of bottlenecking on latency or the front-end. (I think that's what you mean by "allocation resources", but post the IACA summary output.) There will always be a bottleneck of some sort. As far as correctly applying brute-force, FMA throughput (port0 / port5 saturated) is the bottleneck you want to reach. Running faster would require recombining your operations to do more FMA and less add / mul, or otherwise save ops, but that may not be possible with exact results. – Peter Cordes Dec 17 '17 at 19:03
  • (Or of course algorithmic improvements instead of just tweaking the implementation of your existing algorithm). I haven't looked in detail at your source. Hoping to soon, but please post the IACA header output. – Peter Cordes Dec 17 '17 at 19:04
  • IACA_trace_analysis: https://github.com/NudeSurfer/Factoring/blob/master/IACA_trace_analysis.txt IACA analysis: https://github.com/NudeSurfer/Factoring/blob/master/IACA_analysis.txt – Nuutti Dec 17 '17 at 20:57
  • Looks pretty good; IACA predicts 29c per iteration, with port0/port5 busy with useful work 26.5 of those cycles on average. Of course, the compiler could have done a better job with the `kmovw`. Using `korw` would have to run on a vector ALU port, but so does `kmov`. Using 6 `korw` and one `kortestw` would do the trick, instead of 8 `kmovw`. But there's no way to fix this missed-optimization without writing the loop in asm. I didn't look at the pipeline trace to try to figure out *where* the resource stall happened. – Peter Cordes Dec 18 '17 at 16:28
  • Is IACA correct? Do `perf` counters from testing on real hardware (e.g. ocperf.py) back this up? – Peter Cordes Dec 18 '17 at 16:28
  • Try increasing your unrolling factor to 8 or even 12. There might actually be enough registers for that. IOW, it might be latency-bound - which will also show up as being "back-end bound". /cc @PeterCordes – Mysticial Dec 18 '17 at 16:45
  • Also, you don't need to branch that quickly. Assuming that the probability that a particular factor will succeed is extremely low, you can just OR all the masks together and check it once every thousand? iterations? Then if it shows a success, you can re-run the block to find out exactly which factor it is. – Mysticial Dec 18 '17 at 16:47

1 Answer

As a few commenters have suggested: a "backend" bottleneck is what you'd expect for this code. That suggests you're keeping things pretty well fed, which is what you want.

Looking at the report, there should be an opportunity in this section:

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)

From the IACA analysis (first column: uops; remaining columns: cycles of pressure on ports 0 through 7):

|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r11d, k0
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw eax, k1
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw ecx, k2
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw esi, k3
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw edi, k4
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r8d, k5
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r9d, k6
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r10d, k7
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, eax
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, ecx
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, esi
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, edi
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, r8d
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, r9d
|   1*     |             |      |             |             |      |      |      |      | or r11d, r10d

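The compiled code moves the resulting comparison masks (k0-k7) into general-purpose registers with kmovw, which IACA shows running on port 0, just to OR them together. You should be able to eliminate those moves and do the OR reduction inside the mask registers: 6 korw plus one kortestw, instead of 8 kmovw followed by 7 scalar or instructions. Since korw and kortestw also need a vector ALU port, the saving is mostly the scalar ORs plus one uop.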
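A sketch of the mask-register reduction with intrinsics, assuming GCC or Clang: `_mm512_kor` and `_mm512_kortestz` are the AVX-512F intrinsics for korw and kortestw. They take 16-bit `__mmask16` operands, so the `__mmask8` compare results zero-extend into them. The function name `any_factor_found` is hypothetical, the target attribute is only there so the snippet builds without `-mavx512f`, and whether the compiler actually keeps everything in k registers is not guaranteed, so check the generated asm.

```c
#include <assert.h>
#include <immintrin.h>

/* OR the eight found_factor masks with korw (_mm512_kor) and test the
   result with kortestw (_mm512_kortestz): 6 korw + 1 kortestw. Intended to
   be inlined into the AVX-512 loop, where the masks already live in k
   registers. Returns 1 if any mask bit is set, else 0. The target attribute
   lets this build without -mavx512f; only call it on an AVX-512F CPU. */
__attribute__((target("avx512f")))
static int any_factor_found(__mmask16 m11, __mmask16 m12, __mmask16 m13, __mmask16 m14,
                            __mmask16 m21, __mmask16 m22, __mmask16 m23, __mmask16 m24)
{
    __mmask16 a = _mm512_kor(m11, m12);
    __mmask16 b = _mm512_kor(m13, m14);
    __mmask16 c = _mm512_kor(m21, m22);
    __mmask16 d = _mm512_kor(m23, m24);
    a = _mm512_kor(a, b);
    c = _mm512_kor(c, d);
    return !_mm512_kortestz(a, c);   /* kortestz returns 1 when (a | c) == 0 */
}
```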

NOTE: the found_factor_mask variables are declared as __mmask8, which is the right type here: a 512-bit vector holds 8 doubles, so each compare produces 8 mask bits. Whether the OR reduction stays in mask registers is up to the compiler; if it doesn't, drop to assembly as a commenter noted.

Related: what fraction of iterations fire this OR-mask clause? As another commenter observed, you should be able to unroll this with an accumulating OR. Check the accumulated value at the end of each unrolled iteration (or after N iterations), and if it is true, go back and redo the computation to figure out which n value triggered it.

(You could also binary search within the unrolled block to find the matching n value, which might gain a little more.)
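A scalar sketch of the accumulate-then-replay idea for a single prime, assuming hits are rare. `step()` is a hypothetical stand-in for the real FMA/AVX-512 update (here exact 128-bit arithmetic via the GCC/Clang `unsigned __int128` extension), the block size is an arbitrary tuning parameter, and instead of calling `report_factor` the sketch records hits in arrays so it can be checked. The replay is a plain linear rescan rather than a binary search.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-in for one step of the real kernel: exact (r * n) mod p. */
static uint64_t step(uint64_t r, uint64_t n, uint64_t p)
{
    return (uint64_t)(((unsigned __int128)r * n) % p);
}

/* Deferred checking for one prime p (> 2). Inside each block of `block`
   consecutive n values, hits are only OR-ed into a flag. If the flag is set,
   the block is replayed from its saved starting residue and each hit is
   recorded (up to cap entries): sign -1 means p | n!-1, +1 means p | n!+1.
   Returns the number of hits recorded. */
static size_t scan_deferred(uint64_t nmin, uint64_t nmax, uint64_t p, uint64_t block,
                            uint64_t *found_n, int *found_sign, size_t cap)
{
    assert(p > 2 && block > 0);
    size_t count = 0;
    uint64_t r = 2;                                  /* 2! mod p */
    for (uint64_t n0 = 3; n0 <= nmax; n0 += block) {
        uint64_t n1 = (nmax - n0 < block) ? nmax : n0 + block - 1;
        uint64_t r_start = r;                        /* residue entering this block */
        unsigned hit = 0;
        for (uint64_t n = n0; n <= n1; n++) {        /* hot loop: checks are OR-ed, not branched on */
            r = step(r, n, p);
            hit |= (n >= nmin) & ((r == 1) | (r == p - 1));
        }
        if (!hit)
            continue;                                /* the common case */
        for (uint64_t n = n0, rr = r_start; n <= n1; n++) {  /* rare: replay with checks */
            rr = step(rr, n, p);
            if (n < nmin || count >= cap)
                continue;
            if (rr == 1) {
                found_n[count] = n; found_sign[count++] = -1;
            } else if (rr == p - 1) {
                found_n[count] = n; found_sign[count++] = +1;
            }
        }
    }
    return count;
}
```

For example, with p = 7 the scan finds 3! + 1 = 7, 5! - 1 = 119 = 17 * 7 and 6! + 1 = 721 = 103 * 7. In the vectorized kernel the same structure would mean OR-ing the compare masks into one accumulator and saving the residue vectors at the start of each block.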

Next, you should be able to get rid of this mid-loop check:

    // if we are below nmin then we continue next iteration
    if (n < nmin) continue;

Which shows up here:

|   1*     |             |      |             |             |      |      |      |      | cmp r14, 0x3e8
|   0*F    |             |      |             |             |      |      |      |      | jb 0x229

It may not be a huge gain since the predictor will (probably) get this one (mostly) right, but you should get some gains by having two distinct loops for two "phases":

  • n=3 to n=nmin-1
  • n=nmin and beyond

Even a single cycle saved is about 3% of the roughly 29 cycles per iteration that IACA predicts. And since this check is tied to the big OR operation above, there may be more cleverness to be found there.
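The two-phase split can be sketched in scalar C for one prime. Here `step()` is again a hypothetical exact stand-in for the FMA update (GCC/Clang `unsigned __int128`), and the function only counts hits instead of reporting them.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical stand-in for one step of the real kernel: exact (r * n) mod p. */
static uint64_t step(uint64_t r, uint64_t n, uint64_t p)
{
    return (uint64_t)(((unsigned __int128)r * n) % p);
}

/* Two-phase scan for one prime p (> 2): phase 1 only advances the residue
   for n < nmin, phase 2 advances and checks for n >= nmin, so neither loop
   contains the `n < nmin` test. Returns the number of n in [nmin, nmax]
   with p | n! - 1 or p | n! + 1. */
static unsigned scan_two_phase(uint64_t nmin, uint64_t nmax, uint64_t p)
{
    assert(p > 2);
    uint64_t r = 2, n = 3;                   /* r = 2! mod p */
    for (; n < nmin && n <= nmax; n++)       /* phase 1: warm-up, no checks */
        r = step(r, n, p);
    unsigned hits = 0;
    for (; n <= nmax; n++) {                 /* phase 2: every step is checked */
        r = step(r, n, p);
        hits += (r == 1) | (r == p - 1);
    }
    return hits;
}
```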

Answered by payne.

  • removing the branch and separating the loop into two phases probably won't help at all if the code is really *back-end* bound even when it's taken and maybe creating some front-end bubbles. `cmp/jcc` runs on port 6, which doesn't have any vector ALUs. But worth trying, and lower uop throughput will make it slightly more hyperthreading-friendly, at the very minor cost of a slightly larger uop-cache footprint. – Peter Cordes Sep 25 '18 at 00:35