1

Hi, I saw an online answer for counting the distinct prime factors of a number, and it looked non-optimal. So I tried to improve it, but in a simple benchmark my variant is much slower than the original.

The algorithm counts the distinct prime factors of a number. The original uses a HashSet to collect the factors, then calls size() to get their count. My "improved" version uses an int counter and splits each while loop into an if/while pair to avoid unnecessary calls.

Update: tl;dr (see the accepted answer for details)

The original code had a performance bug, calling Math.sqrt unnecessarily, which the JIT compiler happened to fix:

int n = ...;
// sqrt does not need to be recomputed if n does not change
for (int i = 3; i <= Math.sqrt(n); i += 2) {
    while (n % i == 0) {
        n /= i;
    }
}

The JIT compiler optimized the code so that sqrt was only recomputed when n changed. But making the loop body a little more complex (with no functional change) stopped the JIT from optimizing it that way, so sqrt was called on every iteration.
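One way to make the loop bound independent of whether the JIT hoists the call is to avoid Math.sqrt altogether and compare i * i against n. This is a swapped-in technique, not the code from the question; the class and method names below are hypothetical. It is a minimal sketch that uses a long product so that i * i cannot overflow for inputs near Integer.MAX_VALUE.

```java
// Hypothetical sketch: distinct prime factor count with an integer loop bound.
public class SquareBoundSketch {

    // Counts distinct prime factors of n (n >= 2) using (long) i * i <= n
    // instead of i <= Math.sqrt(n), so no sqrt call can end up in the loop.
    static int countDistinctPrimeFactors(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++;                     // 2 is a factor: count it once
            do {
                n /= 2;
            } while (n % 2 == 0);
        }
        // The bound tightens automatically whenever n shrinks.
        for (int i = 3; (long) i * i <= n; i += 2) {
            if (n % i == 0) {
                count++;                 // count each distinct factor once
                do {
                    n /= i;
                } while (n % i == 0);
            }
        }
        if (n > 1) {
            count++;                     // the odd remainder is itself a prime factor
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countDistinctPrimeFactors(1092));    // 2^2 * 3 * 7 * 13, prints 4
        System.out.println(countDistinctPrimeFactors(9699690)); // 2 * 3 * 5 * ... * 19, prints 8
    }
}
```

The multiplication costs far less than a square root, and the comparison is recomputed naturally on every iteration, so there is nothing left for the JIT to hoist.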

Original question

import java.util.HashSet;
import java.util.Set;

public class PrimeFactors {

    // fast version, takes 10s for input 8
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // slow version, takes 19s for input 8
    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    static int findNumberWithNPrimeFactors(final int n) {
        for (int i = 3; ; i++) {
            // switch implementations
            if (countPrimeFactorsCounter(i) == n) {
            // if (countPrimeFactorsSet(i) == n) {
                return i;
            }
        }
    }

    public static void main(String[] args) {
        findNumberWithNPrimeFactors(8); // benchmark warmup
        findNumberWithNPrimeFactors(8);
        long start = System.currentTimeMillis();
        int result = findNumberWithNPrimeFactors(8);
        long duration = System.currentTimeMillis() - start;

        System.out.println("took ms " + duration + " to find " + result);
    }
}

The output for the original version is consistently around 10s (on Java 8), whereas the "optimized" version is closer to 20s (both print the same result). In fact, just changing the single while loop into an if block containing a while loop already halves the speed of the original method.
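Because both methods are supposed to return the same count, a cheap sanity check before comparing their speed is to run them side by side over a range of inputs. The sketch below copies the two methods from the question unchanged; the class name PrimeFactorsCheck, the helper firstMismatch and the range limit are hypothetical choices for illustration.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical harness: checks that the two implementations from the
// question return identical counts for every n in [2, limit).
public class PrimeFactorsCheck {

    // Set-based version, copied from the question
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // Counter-based version, copied from the question
    static int countPrimeFactorsCounter(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++;
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++;
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    // Returns the first n in [2, limit) where the two versions disagree, or -1 if none.
    static int firstMismatch(int limit) {
        for (int n = 2; n < limit; n++) {
            if (countPrimeFactorsSet(n) != countPrimeFactorsCounter(n)) {
                return n;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        int mismatch = firstMismatch(1_000_000);
        // prints "both versions agree"
        System.out.println(mismatch < 0 ? "both versions agree" : "mismatch at " + mismatch);
    }
}
```

A check like this confirms that any timing difference comes from how the code is compiled, not from the two methods doing different amounts of work.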

Using -Xint to run the JVM in interpreted mode, the "optimized" version runs 3 times faster than the original. Using -Xcomp makes both implementations run at a similar speed. So it seems the JIT can optimize the version with a single while loop and a HashSet better than the version with a simple int counter.

Would a proper microbenchmark (How do I write a correct micro-benchmark in Java?) tell me something else? Is there a performance optimization principle I overlooked (e.g. Java performance tips)?

tkruse
  • Did you profile the code? –  Jul 02 '19 at 13:40
  • No, else I would add the results. – tkruse Jul 02 '19 at 13:41
  • Then maybe you should do this. –  Jul 02 '19 at 13:45
  • @JamesKPolk That's what I thought at first, but I realized that he's calling a method that iterates over every number (starting from `3`) to search for a number with `8` prime factors. – Jacob G. Jul 02 '19 at 13:51
  • You do an extra division just before the while loop; turn that into a **do-while** loop. – Joop Eggen Jul 02 '19 at 14:08
  • Have you tried with inputs larger than 8? Note that for very small inputs, the remainder operation (%) may be run more times in the optimized code than in the original code. You should time this with much larger inputs. – Thomas Bitonti Jul 02 '19 at 14:19
  • Also, it looks like the original implementation mistakenly uses a Set to store prime factors. There is meaning to storing factors more than once, since that records the multiplicity of each factor. It's just that Set is the wrong type to use for storage if this is the goal. – Thomas Bitonti Jul 02 '19 at 14:21
  • @ThomasBitonti I vaguely remember this now. I believe the original goal was to record only those numbers that were divisible by four or more `distinct` factors. That is why the set was used. – WJS Jul 02 '19 at 14:25
  • There are a couple of other problems: The code which is being timed should be run several times. More importantly, the System.out *MUST* be removed from the timed block. My estimate is that the time of the System.out is likely higher than the time of the test code on small inputs. – Thomas Bitonti Jul 02 '19 at 14:30
  • @ThomasBitonti: Did *you* run the code? – President James K. Polk Jul 02 '19 at 14:40
  • Despite the non-pertinent comments and answers, it appears that each method performs the exact same number of integer divides and integer remainder operations. I cannot explain the large timing discrepancy so I will be monitoring this question in hopes of a relevant answer. – President James K. Polk Jul 02 '19 at 14:49
  • For general interest: See https://stackoverflow.com/questions/2267146/what-is-the-fastest-factorization-algorithm, which discusses "best" factorization algorithms. – Thomas Bitonti Jul 02 '19 at 16:32
  • I still don't know why ... but running with the `-Xint` jvm command-line option results in a reversal ... the "optimized" code runs significantly faster than the code using `Set`. Try it to confirm please. If true then something unexpected is happening from JITing. – President James K. Polk Jul 02 '19 at 22:48
  • @JamesKPolk For me `-Xint` also makes the Set-based algorithm much slower than the int-counter one. @ThomasBitonti, I changed the code to remove the final println from the timed block. 8 is already a "large" input, as it runs for 10s, calling the method roughly 10 million times. I changed the description to explain that it's distinct factors that are computed. – tkruse Jul 02 '19 at 23:04

4 Answers

4

I converted your example into a JMH benchmark to make fair measurements, and indeed the set variant turned out to be twice as fast as counter:

Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5   717,976 ±  7,232  ops/ms
PrimeFactors.set      thrpt    5  1410,705 ± 15,894  ops/ms

To find out the reason, I reran the benchmark with the built-in -prof xperfasm profiler. It turned out that the counter method spent more than 60% of its time executing the vsqrtsd instruction, which is obviously the compiled counterpart of Math.sqrt(n).

  0,02%   │  │  │     │  0x0000000002ab8f3e: vsqrtsd %xmm0,%xmm0,%xmm0    <-- Math.sqrt
 61,27%   │  │  │     │  0x0000000002ab8f42: vcvtsi2sd %r10d,%xmm1,%xmm1

At the same time, the hottest instruction of the set method was idiv, the result of compiling n % i.

             │  │ ││  0x0000000002ecb9e7: idiv   %ebp               ;*irem
 55,81%      │  ↘ ↘│  0x0000000002ecb9e9: test   %edx,%edx

It's not a surprise that Math.sqrt is a slow operation. But why was it executed more frequently in the first case?

The clue is the transformation of the code you made during optimization. You wrapped a simple while loop in an extra if block. This made the control flow a little more complex, so the JIT failed to hoist the Math.sqrt computation out of the loop and had to recompute it on every iteration.

We need to help the JIT compiler a bit in order to bring the performance back. Let's hoist the Math.sqrt computation out of the loop manually.

    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
        for (int i = 3; i <= sn; i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
            sn = Math.sqrt(n);     // recompute after n changes
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
        for (int i = 3; i <= sn; i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
                sn = Math.sqrt(n);     // recompute after n changes
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

Now the counter method is fast! It is even a bit faster than set, which is to be expected: it does the same amount of computation, minus the Set overhead.

Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5  1513,228 ± 13,046  ops/ms
PrimeFactors.set      thrpt    5  1411,573 ± 10,004  ops/ms

Note that set performance did not change, because the JIT was able to do the same optimization itself, thanks to a simpler control flow graph.

Conclusion: Java performance is a really complicated thing, especially when it comes to micro-optimizations. JIT optimizations are fragile, and it's hard to understand the JVM's mind without specialized tools like JMH and profilers.

apangin
  • Finally, someone who addressed the question. And, even better, answered it! – President James K. Polk Jul 02 '19 at 22:50
  • So the main issue with the original code is that it computes sqrt too often in theory, but the compiler could fix that as long as the rest of the code was simple enough to see whether n can change or not. So the original code has a performance bug that is luckily auto-fixed in this case. – tkruse Jul 03 '19 at 00:29
  • @tkruse Exactly – apangin Jul 03 '19 at 07:21
0

First off, there are two sets of operations in the tests: testing for factors, and recording those factors. When switching between implementations, using a Set versus using an ArrayList (in my rewrite, below) versus simply counting the factors will make a difference.

Second, I'm seeing very large variations in the timings. This is running from Eclipse. I have no clear sense of what is causing the big variations.

My 'lesson learned' is to be mindful of what exactly is being measured. Is the intent to measure the factorization algorithm itself (the cost of the while loops plus the arithmetic operations)? Should the time spent recording the factors be included?

A minor technical point: the lack of multiple-value-setq, which is available in Lisp, is keenly felt in this implementation. One would much rather perform the remainder and the integer division as a single operation than write them out as two distinct steps. From a language and algorithm studies perspective, this is worth looking up.
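Java has no single operator that returns both quotient and remainder, but the idea can be approximated by computing the quotient once and testing divisibility with a multiplication (q * i == n) instead of a separate %. The sketch below is a hypothetical variant for illustration (class and method names are made up); it counts factors with multiplicity like findFactorsCount, and whether it is actually faster depends on how the JIT compiles each form, so it would need to be measured.

```java
// Hypothetical sketch: reuse one quotient instead of evaluating both n % i and n / i.
public class QuotientOnceSketch {

    // Counts prime factors of n (n >= 2) with multiplicity.
    static int countFactorsWithMultiplicity(int n) {
        int count = 0;
        // Candidates: 2, then odd numbers 3, 5, 7, ...
        for (int i = 2; (long) i * i <= n; i = (i == 2) ? 3 : i + 2) {
            int q = n / i;               // one division per divisibility test
            while (q * i == n) {         // true exactly when i divides n (q * i <= n, no overflow)
                n = q;                   // keep the quotient we already computed
                count++;
                q = n / i;
            }
        }
        if (n > 1) {
            count++;                     // leftover prime factor
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countFactorsWithMultiplicity(18));      // 2 * 3 * 3, prints 3
        System.out.println(countFactorsWithMultiplicity(4096210)); // 2 * 5 * 19 * 21559, prints 4
    }
}
```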

Here are timing results for three variations of the factorization implementation. The first is from the initial (un-optimized) implementation, but changed to use a simple List instead of a harder to time Set to store the factors. The second is your optimization, but still tracking using a List. The third is your optimization, but including the change to count the factors.

Average time per call, in ns:

Test value  Non-optimized  Optimized  Counting  Iterations  Trial
        18           3790       1450      2410          10
        64           1630       1220       260          10
      1091          16170       2850      1180          10
      1092           2720       1370       380          10
   4096210          28830       5430      9120          10      1
   4096210          18380       6190      5920          10      2
   4096210          10072       5816      4836         100      1
   4096210           7202       5036      3682         100      1

---

Test value [ 18 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621713914872600 (ns) ]
End   [ 1621713914910500 (ns) ]
Delta [ 37900 (ns) ]
Avg   [ 3790 (ns) ]
Factors: [2, 3, 3]
Times [optimized]
Start [ 1621713915343500 (ns) ]
End   [ 1621713915358000 (ns) ]
Delta [ 14500 (ns) ]
Avg   [ 1450 (ns) ]
Factors: [2, 3, 3]
Times [counting]
Start [ 1621713915550400 (ns) ]
End   [ 1621713915574500 (ns) ]
Delta [ 24100 (ns) ]
Avg   [ 2410 (ns) ]
Factors: 3
---
Test value [ 64 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621747046013900 (ns) ]
End   [ 1621747046030200 (ns) ]
Delta [ 16300 (ns) ]
Avg   [ 1630 (ns) ]
Factors: [2, 2, 2, 2, 2, 2]
Times [optimized]
Start [ 1621747046337800 (ns) ]
End   [ 1621747046350000 (ns) ]
Delta [ 12200 (ns) ]
Avg   [ 1220 (ns) ]
Factors: [2, 2, 2, 2, 2, 2]
Times [counting]
Start [ 1621747046507900 (ns) ]
End   [ 1621747046510500 (ns) ]
Delta [ 2600 (ns) ]
Avg   [ 260 (ns) ]
Factors: 6
---
Test value [ 1091 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621687024226500 (ns) ]
End   [ 1621687024388200 (ns) ]
Delta [ 161700 (ns) ]
Avg   [ 16170 (ns) ]
Factors: [1091]
Times [optimized]
Start [ 1621687024773200 (ns) ]
End   [ 1621687024801700 (ns) ]
Delta [ 28500 (ns) ]
Avg   [ 2850 (ns) ]
Factors: [1091]
Times [counting]
Start [ 1621687024954900 (ns) ]
End   [ 1621687024966700 (ns) ]
Delta [ 11800 (ns) ]
Avg   [ 1180 (ns) ]
Factors: 1
---
Test value [ 1092 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621619636267500 (ns) ]
End   [ 1621619636294700 (ns) ]
Delta [ 27200 (ns) ]
Avg   [ 2720 (ns) ]
Factors: [2, 2, 3, 7, 13]
Times [optimized]
Start [ 1621619636657100 (ns) ]
End   [ 1621619636670800 (ns) ]
Delta [ 13700 (ns) ]
Avg   [ 1370 (ns) ]
Factors: [2, 2, 3, 7, 13]
Times [counting]
Start [ 1621619636895300 (ns) ]
End   [ 1621619636899100 (ns) ]
Delta [ 3800 (ns) ]
Avg   [ 380 (ns) ]
Factors: 5
---
Test value [ 4096210 ]
Warm-up count [ 2 ]
Test count [ 10 ]
Times [non-optimized]
Start [ 1621652753519800 (ns) ]
End   [ 1621652753808100 (ns) ]
Delta [ 288300 (ns) ]
Avg   [ 28830 (ns) ]
Factors: [2, 5, 19, 21559]
Times [optimized]
Start [ 1621652754116300 (ns) ]
End   [ 1621652754170600 (ns) ]
Delta [ 54300 (ns) ]
Avg   [ 5430 (ns) ]
Factors: [2, 5, 19, 21559]
Times [counting]
Start [ 1621652754323500 (ns) ]
End   [ 1621652754414700 (ns) ]
Delta [ 91200 (ns) ]
Avg   [ 9120 (ns) ]
Factors: 4

Here is my rewrite of the test code. The methods of most interest are findFactors, findFactorsOpt, and findFactorsCount.

package my.tests;

import java.util.ArrayList;
import java.util.List;

public class PrimeFactorsTest {

    public static void main(String[] args) {
        // Three arguments are required: args[0], args[1] and args[2]
        if ( args.length < 3 ) {
            System.out.println("Usage: " + PrimeFactorsTest.class.getName() + " testValue warmupIterations testIterations");
            return;
        }

        int testValue = Integer.parseInt(args[0]);
        int warmCount = Integer.parseInt(args[1]);
        int testCount = Integer.parseInt(args[2]);

        if ( testValue < 2 ) {
            System.out.println("Test value [ " + testValue + " ] must be at least 2.");
            return;
        } else {
            System.out.println("Test value [ " + testValue + " ]");
        }
        if ( warmCount < 1 ) {
            System.out.println("Warm-up count [ " + warmCount + " ] must be at least 1.");
            return;
        } else {
            System.out.println("Warm-up count [ " + warmCount + " ]");
        }
        if ( testCount < 1 ) {
            // testCount is also the divisor for the average, so 0 must be rejected
            System.out.println("Test count [ " + testCount + " ] must be at least 1.");
            return;
        } else {
            System.out.println("Test count [ " + testCount + " ]");
        }

        timedFactors(testValue, warmCount, testCount);
        timedFactorsOpt(testValue, warmCount, testCount);
        timedFactorsCount(testValue, warmCount, testCount);
    }

    public static void timedFactors(int testValue, int warmCount, int testCount) {
        List<Integer> factors = new ArrayList<Integer>();

        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            factors.clear();
            findFactors(testValue, factors);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            factors.clear();
            findFactors(testValue, factors);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [non-optimized]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + factors);
    }

    public static void findFactors(int n, List<Integer> factors) {
        while ( n % 2 == 0 ) {
            n /= 2;
            factors.add( Integer.valueOf(2) );
        }

        for ( int factor = 3; factor <= Math.sqrt(n); factor += 2 ) {
            while ( n % factor == 0 ) {
                n /= factor;
                factors.add( Integer.valueOf(factor) );
            }
        }

        if ( n > 2 ) {
            factors.add( Integer.valueOf(n) );
        }
    }

    public static void timedFactorsOpt(int testValue, int warmCount, int testCount) {
        List<Integer> factors = new ArrayList<Integer>();
        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            factors.clear();
            findFactorsOpt(testValue, factors);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            factors.clear();
            findFactorsOpt(testValue, factors);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [optimized]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + factors);
    }

    public static void findFactorsOpt(int n, List<Integer> factors) {
        if ( n % 2 == 0 ) {
            n /= 2;

            Integer factor = Integer.valueOf(2); 
            factors.add(factor);

            while (n % 2 == 0) {
                n /= 2;

                factors.add(factor);
            }
        }

        for ( int factorValue = 3; factorValue <= Math.sqrt(n); factorValue += 2) {
            if ( n % factorValue == 0 ) {
                n /= factorValue;

                Integer factor = Integer.valueOf(factorValue); 
                factors.add(factor);

                while ( n % factorValue == 0 ) {
                    n /= factorValue;
                    factors.add(factor);
                }
            }
        }

        if (n > 2) {
            factors.add( Integer.valueOf(n) );
        }
    }

    public static void timedFactorsCount(int testValue, int warmCount, int testCount) {
        int numFactors = 0;

        for ( int warmNo = 0; warmNo < warmCount; warmNo++ ) {
            numFactors = findFactorsCount(testValue);
        }

        long startTime = System.nanoTime();
        for ( int testNo = 0; testNo < testCount; testNo++ ) {
            numFactors = findFactorsCount(testValue);
        }
        long endTime = System.nanoTime();

        System.out.println("Times [counting]");
        System.out.println("Start [ " + startTime + " (ns) ]");
        System.out.println("End   [ " + endTime + " (ns) ]");
        System.out.println("Delta [ " + (endTime - startTime) + " (ns) ]");
        System.out.println("Avg   [ " + (endTime - startTime) / testCount + " (ns) ]");
        System.out.println("Factors: " + numFactors);
    }

    public static int findFactorsCount(int n) {
        int numFactors = 0;

        if ( n % 2 == 0 ) {
            n /= 2;
            numFactors++;

            while (n % 2 == 0) {
                n /= 2;
                numFactors++;
            }
        }

        for ( int factorValue = 3; factorValue <= Math.sqrt(n); factorValue += 2) {
            if ( n % factorValue == 0 ) {
                n /= factorValue;
                numFactors++;

                while ( n % factorValue == 0 ) {
                    n /= factorValue;
                    numFactors++;
                }
            }
        }

        if (n > 2) {
            numFactors++;
        }

        return numFactors;
    }
}
Thomas Bitonti
  • The question presented was a simple one: why does `countPrimeFactorsCounter()` appear to take almost twice as much time as `countPrimeFactorsSet()`. It was never about the best or better algorithm to use. I don't see how this answers the question. – President James K. Polk Jul 02 '19 at 18:30
  • Thanks, I can try rerunning with this framework to see other performance patterns. Picking specific numbers as inputs here is a bit confusing, because a high number that is a multiple of 2 may need far fewer steps than a prime number in the same range, so 1091 and 1092 can show such big differences. Probably a different algorithm could be chosen to measure the same effect, one where computation scales consistently with the input. – tkruse Jul 02 '19 at 23:23
  • Hi. Yeah, it turns out that the time will strongly depend on the number being factored. I was thinking to try to graph the times for a range of inputs, just to get an initial sense of the range of variation, and maybe how the times depend on the particular factors. One note is that for very small values (like the initial '8') constant factors may dominate the results. That would include Set overhead, which (I'm thinking) will strongly impact the times. – Thomas Bitonti Jul 07 '19 at 22:22
-1

First, your if block here: for (int i = 3; i <= Math.sqrt(n); i += 2) { if (n % i == 0) {...

should be out of the loop.

Secondly, you can write this code in different ways, for example:

while (n % 2 == 0) { current++; n /= 2; }

you can change it to: if (n % 2 == 0) { current++; n /= 2; }

Essentially, you should avoid conditions or instructions inside loops because of your method (findNumberWithNPrimeFactors): the complexity of your algorithm is the complexity of each loop (findNumberWithNPrimeFactors) × (number of iterations).

If you add a test or an assignment inside your loop, you add +1 to that cost on every one of those iterations.

-1

The following makes Math.sqrt superfluous by dividing n down. Continuously comparing with a shrinking square root might even be the slowest operation.

Then a do-while would be better style.

static int countPrimeFactorsCounter2(int n) {
    int count = 0; // using simple int
    if (n % 2 == 0) {
        ++count; // only add on first division
        do {
            n /= 2;
        } while (n % 2 == 0);
    }
    for (int i = 3; i <= n; i += 2) {
        if (n % i == 0) {
            count++; // only add on first division
            do {
                n /= i;
            } while (n % i == 0);
        }
    }
    //if (n > 2) {
    //    ++count;
    //}
    return count;
}

The logical fallacy of using the square root is based on the fact that for all a, b with a·b = n, you only need to try a ≤ √n. However, in a loop that keeps dividing n, you save just one single step. Notice that the sqrt is calculated for every odd number i.
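To see how much work the square-root bound saves when n has no small factors (the prime case raised in the comments below), one can simply count loop iterations for both bounds. The sketch below is an illustration with hypothetical class and method names; it counts iterations of the odd-candidate loop only and does not time anything.

```java
// Hypothetical sketch: compare loop iteration counts for the bounds i <= n and i * i <= n.
public class LoopBoundSketch {

    // Iterations of the odd-candidate loop with the bound i <= n
    // (the same loop bound as countPrimeFactorsCounter2 above). Assumes n >= 1.
    static int iterationsUpToN(int n) {
        while (n % 2 == 0) {
            n /= 2;
        }
        int iterations = 0;
        for (int i = 3; i <= n; i += 2) {
            iterations++;
            while (n % i == 0) {
                n /= i;
            }
        }
        return iterations;
    }

    // Iterations with the square-root bound, written as (long) i * i <= n. Assumes n >= 1.
    static int iterationsUpToSqrt(int n) {
        while (n % 2 == 0) {
            n /= 2;
        }
        int iterations = 0;
        for (int i = 3; (long) i * i <= n; i += 2) {
            iterations++;
            while (n % i == 0) {
                n /= i;
            }
        }
        return iterations;
    }

    public static void main(String[] args) {
        int prime = 1091; // a prime input, as in the timings in the previous answer
        System.out.println(iterationsUpToN(prime) + " vs " + iterationsUpToSqrt(prime)); // prints "545 vs 16"
    }
}
```

For a prime n the i <= n loop walks every odd candidate up to n itself, while the square-root bound stops after about √n / 2 candidates, so the saving grows with n rather than being a single step.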

Joop Eggen
  • `Math.sqrt` is not superfluous (e.g. when `n` is prime and there are no divisions in the inner loop). Try to run your "optimized" version - it will take forever to complete. Also, do-while does not make any difference here performance-wise. But most importantly, your answer is not related to the original question, which wasn't about optimizing the algorithm at all, but rather about performance *difference* between two given methods. – apangin Jul 02 '19 at 20:50
  • @apangin `i <= n` should also take n being prime into account. do-while is cosmetic. The sqrt + looping was the main difference; but I admit I did not give a comparison. And your answer is excellent. Thanks for scrutinizing my answer. I did not try my method, but it seems that the extra loop step with % costs. Thanks for that insight. – Joop Eggen Jul 03 '19 at 07:29