I need to do the following arithmetic:

long a,b,c;
long result = a*b/c;

While the result is guaranteed to fit in a long, the intermediate product a*b is not, so it can overflow.
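The following small illustration makes the failure concrete. The class and method names are made up for this example. With a = Long.MAX_VALUE and b = c = 2, the true result fits in a long, but the intermediate product wraps around before the division. `Math.multiplyExact` can at least detect the overflow instead of silently producing a wrong value.

```java
public class OverflowDemo {
    // The expression from the question: a * b silently wraps around on overflow.
    static long naive(long a, long b, long c) {
        return a * b / c;
    }

    // True if a * b does not fit in a long (Math.multiplyExact throws in that case).
    static boolean productOverflows(long a, long b) {
        try {
            Math.multiplyExact(a, b);
            return false;
        } catch (ArithmeticException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        long a = Long.MAX_VALUE, b = 2, c = 2;
        // The true value of a * b / c is Long.MAX_VALUE, which fits,
        // but a * b wraps around to -2 before the division happens.
        System.out.println(productOverflows(a, b)); // true
        System.out.println(naive(a, b, c));         // -1
    }
}
```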

I tried to do it step by step (first multiply, then divide), dealing with the overflow by splitting the intermediate result of a*b into an int array of at most 4 elements (much like BigInteger uses its int[] mag field).

Here I got stuck with the division: I cannot get my head around the bitwise shifts required to do a precise division. All I need is the quotient (I don't need the remainder).

The hypothetical method would be:

public static long divide(int[] dividend, long divisor)
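Here is a minimal sketch of that hypothetical method. It uses plain shift-and-subtract (restoring) binary long division, one bit at a time, not anything BigInteger does internally. It assumes the limbs hold an unsigned magnitude in big-endian order (most significant limb first, like BigInteger's mag), a positive divisor, and a quotient that fits in a long. Sign handling is left out, and the class name is illustrative. The demo in main needs Java 9+ for `Math.multiplyHigh`.

```java
public class LimbDivision {
    /**
     * Divides an unsigned big-endian magnitude of 32-bit limbs by a positive long.
     * Assumes divisor > 0 and that the quotient fits in a non-negative long.
     */
    public static long divide(int[] dividend, long divisor) {
        if (divisor <= 0) {
            throw new IllegalArgumentException("divisor must be positive");
        }
        long quotient = 0;
        long remainder = 0; // invariant: 0 <= remainder < divisor < 2^63
        for (int limb : dividend) {
            for (int bit = 31; bit >= 0; bit--) {
                // Shift the next dividend bit into the remainder. Since remainder < 2^63,
                // the shifted value (< 2^64) still fits when read as unsigned.
                remainder = (remainder << 1) | ((limb >>> bit) & 1);
                quotient <<= 1;
                if (Long.compareUnsigned(remainder, divisor) >= 0) {
                    remainder -= divisor; // back below divisor
                    quotient |= 1;
                }
            }
        }
        return quotient;
    }

    public static void main(String[] args) {
        long a = Long.MAX_VALUE, b = 3, c = 6;
        long hi = Math.multiplyHigh(a, b); // upper 64 bits of the product (Java 9+)
        long lo = a * b;                   // lower 64 bits (wrapping is intended here)
        int[] limbs = {(int) (hi >>> 32), (int) hi, (int) (lo >>> 32), (int) lo};
        System.out.println(divide(limbs, c)); // 4611686018427387903
    }
}
```

A 4-limb dividend costs up to 128 iterations, so this shows the mechanics of the shifts rather than the fastest way to divide.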

Also, I am not considering BigInteger, as this part of the code needs to be fast (I would like to stick to primitives and primitive arrays).

Any help would be much appreciated!

Edit: I am not trying to implement the whole BigInteger myself. What I am trying to do is to solve a specific problem (a*b/c, where a*b can overflow) faster than using the generic BigInteger.

Edit2: Ideally this could be done in a clever way that avoids overflow altogether. Some tips surfaced in the comments, but I am still looking for one that is correct.

Update: I tried porting the BigInteger code to my specific needs without object creation, and in the first iteration I got a ~46% speed improvement compared to using BigInteger (on my development PC).

Then I tried a slightly modified version of @David Eisenstat's solution, which reduced run times by ~56% (more than 2x) compared to BigInteger, that is ~18% compared to my adapted BigInteger algorithm (I ran 100_000_000_000 random inputs from Long.MIN_VALUE to Long.MAX_VALUE).

There will be more iterations on optimization and testing, but at this point, I think I must accept this answer as the best.

MagicKriss
  • So, you are implementing this yourself the hard way using `int[]` because you are **assuming** that creating objects is slow. This assumption is false, or at least not justified. Also, an array is also an object in Java, so you are now creating one or more `int[]` objects. Instead of basing your code on assumptions, it's better to find **evidence** that your own implementation is faster / more efficient. – Jesper Jan 17 '19 at 09:45
  • Actually, I don't think there is another way than using a bigger data type; take a look at https://stackoverflow.com/questions/849813/large-numbers-in-java. For decimal digits there is BigDecimal, but yes, I am assuming it will degrade performance a little in comparison to long. – xxxvodnikxxx Jan 17 '19 at 09:46
  • You can use [`numberOfLeadingZeros(long)`](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/lang/Long.html#numberOfLeadingZeros(long)) and [`numberOfTrailingZeros(long)`](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/lang/Long.html#numberOfTrailingZeros(long)) to tell *trivial* from *easy* and *hard*. – greybeard Jan 17 '19 at 09:50
  • Specific instrument for specific `MulDiv64` task should work better than generic solution with unnecessary operations. [Arbitrary example](https://stackoverflow.com/questions/8733178/most-accurate-way-to-do-a-combined-multiply-and-divide-operation-in-64-bit) – MBo Jan 17 '19 at 10:11
  • I don't think that this is a duplicate. The OP is not asking for a general-purpose solution but is focused on the operation a*b/c on longs and an ad-hoc solution could be better. –  Jan 17 '19 at 10:23
  • If the result is guaranteed to fit in `long` then try `long result = ((a/c)*b)+(a%c)*b/c;` – Eritrean Jan 17 '19 at 10:35
  • @Eritrean, this looks much more suited to my case, but I can see that the result for `a*b/c` is different from the result for `b*a/c` – MagicKriss Jan 17 '19 at 10:47
  • @MagicKriss did you substitute a with b to calculate `b*a/c`? Meaning `b*a/c = ((b/c)*a)+(b%c)*a/c;` The result should be the same, or am I getting something wrong? – Eritrean Jan 17 '19 at 11:29
  • @Eritrean is the expression `(a%c)*b` itself guaranteed to fit in `long`? – meowgoesthedog Jan 17 '19 at 11:31
  • @meowgoesthedog I think so. If `a*b/c <= Long.MAX_VALUE` then `(a%c)*b <= Long.MAX_VALUE` since `a%c <= c-1`. This is not a mathematical proof, just a feeling :-) – Eritrean Jan 17 '19 at 11:40
  • @Eritrean I think you could turn that in an answer – Joop Eggen Jan 17 '19 at 11:41
  • @meowgoesthedog, ah, yes, the problem lies here as we're dealing with overflow again... e.g., if `long b = Long.MAX_VALUE;` then any `(a%c)>1` will lead to overflow in `(a%c)*b` – MagicKriss Jan 17 '19 at 11:42
  • @Eritrean Counter-example: let `a = b = 2^33, c = 2^34 + 1`; `a % c = 2^33` so `(a % c) * b` would still overflow, but of course the expected result `a * b / c = 2^32 - 1` (after truncation) would not. – meowgoesthedog Jan 17 '19 at 11:45
  • @meowgoesthedog got it. Thank you for your example. Will think about it and will post an answer if I get that solved. – Eritrean Jan 17 '19 at 11:54
  • @MagicKriss you may find [this post](https://codereview.stackexchange.com/questions/67962/mostly-portable-128-by-64-bit-division) useful. – meowgoesthedog Jan 17 '19 at 12:11
  • This [How to multiply a 64 bit integer by a fraction](https://stackoverflow.com/a/25183354) is what I thought of as a safe starting place. – greybeard Jan 17 '19 at 13:50
  • `[assuming 57bit-ish FP-arithmetic] is resolving the precision issue [in (63*63)/63 bit]. Am I missing something?` I think you are, unless you disclose the likes of *last 1…3 bits do not need to be correct*, *error below 1‰ is sufficient*, *c < 2\*\*32* or *a×b<2\*\*65*. – greybeard Jan 17 '19 at 23:55
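The counterexample from the comments can be checked directly. The sketch below computes the reference value with BigInteger and uses `Math.multiplyExact` to show that the intermediate term `(a % c) * b` from the proposed formula overflows. The class name is illustrative. The exact result works out to 2^32 - 1 (2^66 / (2^34 + 1) is just under 2^32), well within range.

```java
import java.math.BigInteger;

public class CounterExample {
    // Reference value of truncate(a * b / c) via arbitrary precision.
    static long exact(long a, long b, long c) {
        return BigInteger.valueOf(a).multiply(BigInteger.valueOf(b))
                .divide(BigInteger.valueOf(c)).longValueExact();
    }

    public static void main(String[] args) {
        long a = 1L << 33, b = 1L << 33, c = (1L << 34) + 1;
        System.out.println(exact(a, b, c)); // 4294967295, i.e. 2^32 - 1
        try {
            // The intermediate term (a % c) * b from the proposed formula.
            Math.multiplyExact(a % c, b);
        } catch (ArithmeticException e) {
            System.out.println("(a % c) * b overflows");
        }
    }
}
```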

5 Answers

I've been tinkering with an approach that (1) multiplies a and b with the schoolbook algorithm on 21-bit limbs and (2) proceeds to do long division by c, with an unusual representation of the residual a*b - c*q that uses a double to store the high-order bits and a long to store the low-order bits. I don't know if it can be made competitive with standard long division, but for your enjoyment:

public class MulDiv {
  public static void main(String[] args) {
    java.util.Random r = new java.util.Random();
    for (long i = 0; true; i++) {
      if (i % 1000000 == 0) {
        System.err.println(i);
      }
      long a = r.nextLong() >> (r.nextInt(8) * 8);
      long b = r.nextLong() >> (r.nextInt(8) * 8);
      long c = r.nextLong() >> (r.nextInt(8) * 8);
      if (c == 0) {
        continue;
      }
      long x = mulDiv(a, b, c);
      java.math.BigInteger aa = java.math.BigInteger.valueOf(a);
      java.math.BigInteger bb = java.math.BigInteger.valueOf(b);
      java.math.BigInteger cc = java.math.BigInteger.valueOf(c);
      java.math.BigInteger xx = aa.multiply(bb).divide(cc);
      if (java.math.BigInteger.valueOf(xx.longValue()).equals(xx) && x != xx.longValue()) {
        System.out.printf("a=%d b=%d c=%d: %d != %s\n", a, b, c, x, xx);
      }
    }
  }

  // Returns truncate(a b/c), subject to the precondition that the result is
  // defined and can be represented as a long.
  private static long mulDiv(long a, long b, long c) {
    // Decompose a.
    long a2 = a >> 42;
    long a10 = a - (a2 << 42);
    long a1 = a10 >> 21;
    long a0 = a10 - (a1 << 21);
    assert a == (((a2 << 21) + a1) << 21) + a0;
    // Decompose b.
    long b2 = b >> 42;
    long b10 = b - (b2 << 42);
    long b1 = b10 >> 21;
    long b0 = b10 - (b1 << 21);
    assert b == (((b2 << 21) + b1) << 21) + b0;
    // Compute a b.
    long ab4 = a2 * b2;
    long ab3 = a2 * b1 + a1 * b2;
    long ab2 = a2 * b0 + a1 * b1 + a0 * b2;
    long ab1 = a1 * b0 + a0 * b1;
    long ab0 = a0 * b0;
    // Compute a b/c.
    DivBy d = new DivBy(c);
    d.shift21Add(ab4);
    d.shift21Add(ab3);
    d.shift21Add(ab2);
    d.shift21Add(ab1);
    d.shift21Add(ab0);
    return d.getQuotient();
  }
}

strictfp class DivBy { // not public: only one public top-level class per source file
  // Initializes n <- 0.
  public DivBy(long d) {
    di = d;
    df = (double) d;
    oneOverD = 1.0 / df;
  }

  // Updates n <- 2^21 n + i. Assumes |i| <= 3 * 2^42.
  public void shift21Add(long i) {
    // Update the quotient and remainder.
    q <<= 21;
    ri = (ri << 21) + i;
    rf = rf * (double) (1 << 21) + (double) i;
    reduce();
  }

  // Returns truncate(n/d).
  public long getQuotient() {
    while (rf != (double) ri) {
      reduce();
    }
    // Round toward zero.
    if (q > 0) {
      if ((di > 0 && ri < 0) || (di < 0 && ri > 0)) {
        return q - 1;
      }
    } else if (q < 0) {
      if ((di > 0 && ri > 0) || (di < 0 && ri < 0)) {
        return q + 1;
      }
    }
    return q;
  }

  private void reduce() {
    // x is approximately r/d.
    long x = Math.round(rf * oneOverD);
    q += x;
    ri -= di * x;
    rf = repairLowOrderBits(rf - df * (double) x, ri);
  }

  private static double repairLowOrderBits(double f, long i) {
    int e = Math.getExponent(f);
    if (e < 64) {
      return (double) i;
    }
    long rawBits = Double.doubleToRawLongBits(f);
    long lowOrderBits = (rawBits >> 63) ^ (rawBits << (e - 52));
    return f + (double) (i - lowOrderBits);
  }

  private final long di;
  private final double df;
  private final double oneOverD;
  private long q = 0;
  private long ri = 0;
  private double rf = 0;
}
David Eisenstat
  • I have tinkered with *31.5*-bit limbs - not plug-compatible with `DivBy.reduce()`, and I have a hunch why (31.5 bit: base `ceil(Math.sqrt(Long.MAX_VALUE))`). Got distracted coding "cleverness" … – greybeard Jan 20 '19 at 21:58
  • @greybeard It's frustrating, isn't it? Between long and double, we can get a 64 + ~50 bit product of 64-bit quantities, which isn't quite enough here. – David Eisenstat Jan 20 '19 at 22:03
  • Frustrating indeed: the bottom line seems to be that for the worst case you need to implement *64×64→128 bit* and *128/64 bit*. The sole promise of reduced effort (compared to, e.g., `BigInteger`) is that you can special-case these few limbs and avoid looping. Java 1.8 has grown a [`long divideUnsigned(long dividend, long divisor)`](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/lang/Long.html#remainderUnsigned(long,long)) and assorted `remainder()` - I'm anything but hell-bent on coding *128/64 bit* regardless. – greybeard Jan 21 '19 at 00:42
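The cheap special case mentioned in that comment could look like the following minimal sketch. It assumes Java 9+ (for `Math.multiplyHigh`) and a, b >= 0, c > 0. The class name is illustrative. The BigInteger fallback merely stands in for the 128/64-bit division and is not part of any posted answer.

```java
import java.math.BigInteger;

public class MulDivFastPath {
    /** truncate(a * b / c) for a >= 0, b >= 0, c > 0 (sign handling omitted). */
    static long mulDiv(long a, long b, long c) {
        if (a < 0 || b < 0 || c <= 0) {
            throw new IllegalArgumentException("expects a, b >= 0 and c > 0");
        }
        // Upper 64 bits; for non-negative inputs this equals the unsigned high word.
        long hi = Math.multiplyHigh(a, b);
        long lo = a * b; // lower 64 bits, interpreted as unsigned
        if (hi == 0) {
            // Product < 2^64: a single unsigned 64/64 division is exact.
            return Long.divideUnsigned(lo, c);
        }
        // Placeholder slow path: a real 128/64-bit division would go here.
        return BigInteger.valueOf(a).multiply(BigInteger.valueOf(b))
                .divide(BigInteger.valueOf(c)).longValueExact();
    }

    public static void main(String[] args) {
        System.out.println(mulDiv(Long.MAX_VALUE, 2, 4)); // fast path: 4611686018427387903
        System.out.println(mulDiv(Long.MAX_VALUE, 3, 6)); // slow path: 4611686018427387903
    }
}
```

The fast path still catches products between 2^63 and 2^64, which a plain signed `a * b / c` would get wrong.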

You can use the greatest common divisor (gcd) to help.

a * b / c = (a / gcd(a,c)) * (b / (c / gcd(a,c)))

Edit: The OP asked me to explain the above equation. Basically, we have:

a = (a / gcd(a,c)) * gcd(a,c)
c = (c / gcd(a,c)) * gcd(a,c)

Let's say x=gcd(a,c) for brevity, and rewrite this.

a*b/c = ((a/x) * x * b) / ((c/x) * x)

Next, we cancel the common factor x:

a*b/c = ((a/x) * b) / (c/x)

You can take this a step further. Let y = gcd(b, c/x):

a*b/c = ((a/x) * (b/y) * y) / (((c/x)/y) * y)

a*b/c = ((a/x) * (b/y)) / (c/(xy))

Here's code to get the gcd.

static long gcd(long a, long b) 
{ 
  if (b == 0) 
    return a; 
  return gcd(b, a % b);  
} 
Dave
  • Thanks for your answer. While it works **most** of the time, there are some cases where it fails, e.g. `a = 13`, `b = 2` and `c = 12` (the answer from this algorithm is `0`, while the correct one is `2`). I am starting to wonder whether it is even possible to avoid overflow and still get correct answers all of the time, as one would lose precision without overflow. – MagicKriss Jan 17 '19 at 20:33
  • I am wondering, what if I modify this answer a bit and get intermediate result as double, like so: `a * b / c = ((double)a / gcd(a,c)) * (b / ((double)c / gcd(a,c)))` Would that work? Or are there some underlying errors that I am not seeing? – MagicKriss Jan 17 '19 at 23:36
  • @MagicKriss I was assuming that a * b is divisible by c. If that's not the case then this will give a different answer sometimes because of rounding. – Dave Jan 17 '19 at 23:38
  • I am curious: how did you get from `a * b / c` to `(a / gcd(a,c)) * (b / (c / gcd(a,c)))`? Could you show the steps in between? – MagicKriss Jan 18 '19 at 07:44
  • I was wrong in my assumptions: the code with GCD would be wrong even in some cases where there is no overflow at all (even with casting to `double`). – MagicKriss Jan 18 '19 at 13:30
  • @MagicKriss It gives the wrong answer when you round due to integer division because of the order we're doing things. If you calculate the full numerator then divide (see above equations) it should always give the right answer, but in that case you risk overflow. – Dave Jan 18 '19 at 13:41
  • Seeing how `b / (c / gcd(a,c))` provokes plain wrong answers, using the compact form `(a/x) * (b/y) / (c/(x*y))` from your *step further* seems prudent. Doesn't prevent overflow from GCDs particularly small or large. Oh, wait, `(a/x) * (b/y) / ((c/x)/y)` (simplify as much as possible, but…) shouldn't suffer from large GCDs, leaving "the coprime problem" (unsolved, that is). – greybeard Jan 19 '19 at 23:58
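Following the compact form from the last comment, the sketch below divides out both common factors before multiplying, assuming a, b >= 0 and c > 0. The class name is illustrative. The divisions by x and y are exact, so only the final division truncates. This fixes the `a = 13, b = 2, c = 12` case. When a and b are coprime with c, though, nothing cancels, so `Math.multiplyExact` is used to report the remaining overflow rather than hide it.

```java
public class GcdMulDiv {
    // Iterative Euclid; assumes non-negative arguments.
    static long gcd(long a, long b) {
        while (b != 0) {
            long t = a % b;
            a = b;
            b = t;
        }
        return a;
    }

    /**
     * truncate(a * b / c) for a, b >= 0 and c > 0, computed as
     * (a/x) * (b/y) / ((c/x)/y) with x = gcd(a, c) and y = gcd(b, c/x).
     * Throws ArithmeticException when the reduced numerator still overflows.
     */
    static long mulDiv(long a, long b, long c) {
        long x = gcd(a, c);
        long y = gcd(b, c / x);
        long numerator = Math.multiplyExact(a / x, b / y);
        return numerator / ((c / x) / y);
    }

    public static void main(String[] args) {
        System.out.println(mulDiv(13, 2, 12)); // 2
        try {
            // Coprime case from the question's comments: nothing cancels.
            mulDiv(1L << 33, 1L << 33, (1L << 34) + 1);
        } catch (ArithmeticException e) {
            System.out.println("reduced product still overflows");
        }
    }
}
```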

David Eisenstat got me thinking some more.
I want simple cases to be fast: let double take care of that. Newton-Raphson may be a better choice for the rest.

 /** Multiplies both <code>factor</code>s
  *  and divides by <code>divisor</code>.
  * @return <code>Long.MIN_VALUE</code> if result out of range,<br/>
  *     else <code>factorA * factor1 / divisor</code> */
    public static long
    mulDiv(long factorA, long factor1, long divisor) {
        final double dd = divisor,
            product = (double)factorA * factor1,
            a1_d = product / dd;
        if (a1_d < -TOO_LARGE || TOO_LARGE < a1_d)
            return tooLarge();
        if (-ONE_ < a1_d && a1_d < ONE_)
            return 0;
        if (-EXACT < product && product < EXACT)
            return (long) a1_d;
        long pLo = factorA * factor1, //diff,
            pHi = high64(factorA, factor1);
        if (a1_d < -LONG_MAX_ || LONG_MAX_ < a1_d) {
            long maxdHi = divisor >> 1;
            if (maxdHi < pHi
                || maxdHi == pHi
                   && Long.compareUnsigned((divisor << Long.SIZE-1),
                                           pLo) <= 0)
                return tooLarge();
        }
        final double high_dd = TWO_POWER64/dd;
        long quotient = (long) a1_d,
            loPP = quotient * divisor,
            hiPP = high64(quotient, divisor);
        long remHi = pHi - hiPP, // xxx overflow/carry
            remLo = pLo - loPP;
        if (Long.compareUnsigned(pLo, remLo) < 0)
            remHi -= 1;
        double fudge = remHi * high_dd;
        if (remLo < 0)
            fudge += high_dd;
        fudge += remLo/dd;
        long //fHi = (long)fudge/TWO_POWER64,
            fLo = (long) Math.floor(fudge); //*round
        quotient += fLo;
        loPP = quotient * divisor;
        hiPP = high64(quotient, divisor);
        remHi = pHi - hiPP; // should be 0?!
        remLo = pLo - loPP;
        if (Long.compareUnsigned(pLo, remLo) < 0)
            remHi -= 1;
        if (0 == remHi && 0 <= remLo && remLo < divisor)
            return quotient;

        fudge = remHi * high_dd;
        if (remLo < 0)
            fudge += high_dd;
        fudge += remLo/dd;
        fLo = (long) Math.floor(fudge);
        return quotient + fLo;
    }

 /** max <code>double</code> trusted to represent
  *  a value in the range of <code>long</code> */
    static final double
        LONG_MAX_ = Double.valueOf(Long.MAX_VALUE - 0xFFF);
 /** max <code>double</code> trusted to represent a value below 1 */
    static final double
        ONE_ = Double.longBitsToDouble(
                    Double.doubleToRawLongBits(1) - 4);
 /** max <code>double</code> trusted to represent a value exactly */
    static final double
        EXACT = Long.MAX_VALUE >> 12;
    static final double
        TWO_POWER64 = Double.valueOf(1L<<32)*Double.valueOf(1L<<32);

    static long tooLarge() {
//      throw new RuntimeException("result too large for long");
        return Long.MIN_VALUE;
    }
    static final long   ONES_32 = ~(~0L << 32);

    static long high64(long factorA, long factor1) {
        long loA = factorA & ONES_32,
            hiA = factorA >>> 32,
            lo1 = factor1 & ONES_32,
            hi1 = factor1 >>> 32;
        return ((loA * lo1 >>> 32)
                +loA * hi1 + hiA * lo1 >>> 32)
               + hiA * hi1;
    }

(I rearranged this code somewhat from the IDE version to put mulDiv() on top. Being lazy, I have a wrapper for sign handling (not shown); might try and do it properly before hell freezes over.
For timing, a model of the input is sorely needed:
how about one where each possible result is equally likely?)

greybeard
  • From Java 9, `high64()` should be substituted by [`Math.multiplyHigh(long, long)`](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/lang/Math.html#multiplyHigh(long,long)). (Tried to update my IDE to 11: waded through pages on end of legalese (89 p. licence manual? Kid me.). Didn't get Eclipse JDT to accept the openJDK (yet)…) – greybeard Jan 27 '19 at 14:55

Perhaps not clever, but it runs in time linear in the bit width of the type (in C):

#define MUL_DIV_TYPE    unsigned int
#define BITS_PER_TYPE   (sizeof(MUL_DIV_TYPE)*8)
#define TOP_BIT_TYPE    ((MUL_DIV_TYPE)1 << (BITS_PER_TYPE-1))

//
//    result = ( a * b ) / c, without intermediate overflow.
//
MUL_DIV_TYPE mul_div( MUL_DIV_TYPE a, MUL_DIV_TYPE b, MUL_DIV_TYPE c ) {
    MUL_DIV_TYPE    st, sb;     // product sum top and bottom

    MUL_DIV_TYPE    d, e;       // division result

    MUL_DIV_TYPE    i,      // bit counter
            j;      // overflow check

    st = 0;
    sb = 0;

    d = 0;
    e = 0;

    for( i = 0; i < BITS_PER_TYPE; i++ ) {
        //
        //  Shift sum left to make space
        //  for next partial sum
        //
        st <<= 1;
        if( sb & TOP_BIT_TYPE ) st |= 1;
        sb <<= 1;
        //
        //  Add a to the sum if the top bit
        //  of b is set.
        //
        if( b & TOP_BIT_TYPE ) {
            j = sb;
            sb += a;
            if( sb < j ) st++;
        }
        //
        //  Division.
        //
        d <<= 1;
        if( st >= c ) {
            d |= 1;
            st -= c;
            e++;
        }
        else {
            if( e ) e++;
        }
        //
        //  Shift b up by one bit.
        //
        b <<= 1;
    }
    //
    //  Roll in missing bits.
    //
    for( i = e; i < BITS_PER_TYPE; i++ ) {
        //
        //  Shift across product sum
        //
        st <<= 1;
        if( sb & TOP_BIT_TYPE ) st |= 1;
        sb <<= 1;
        //
        //  Division, continued.
        //
        d <<= 1;
        if( st >= c ) {
            d |= 1;
            st -= c;
        }
    }
    return( d );  // remainder should be in st
}

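Divide a/c and b/c into whole and fractional (remainder) parts: write a = x + y and b = z + w, where x = a - a % c and z = b - b % c are multiples of c, and y = a % c, w = b % c. Then, with exact (not integer) division:

a*b/c
= c * (a/c) * (b/c)
= c * (x/c + y/c) * (z/c + w/c)
= xz/c + xw/c + yz/c + yw/c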

As such, you can trivially calculate the first three terms without overflow, because x and z are multiples of c, so the division by c is exact in each of them. In my experience, this is often enough to cover typical overflow cases. However, if your divisor is too large, such that (a % c) * (b % c) overflows, this method still fails. If that's a typical issue for you, you may want to look at other approaches (e.g. repeatedly halving the larger of a and b together with c until nothing overflows; but doing that without introducing additional error from biases in the process is non-trivial, and you'll probably need to keep a running score of the error in a separate variable).

Anyway, the code for the above:

long a,b,c;
long bMod = b % c;
long result = a * (b / c) + (a / c) * bMod + ((a % c) * bMod) / c;

If speed is a big concern (I'm assuming it is, at least to some extent, since you're asking this), you may want to consider storing a/c and b/c in variables and calculating the mod through multiplication, e.g. replacing (a % c) with (a - aDiv * c); this takes you from 4 divisions per call to 2 (not counting the final division by c).
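The two-division variant could look like this sketch, assuming a, b >= 0 and c > 0. The class name is illustrative. The `Math.multiplyExact` on the remainder product is an extra safeguard not in the code above. It turns the failure case described above into an exception instead of a silently wrong result.

```java
public class ModSplitMulDiv {
    /**
     * truncate(a * b / c) for a, b >= 0 and c > 0, via
     * a = aDiv * c + aMod and b = bDiv * c + bMod.
     * Throws ArithmeticException if aMod * bMod overflows.
     */
    static long mulDiv(long a, long b, long c) {
        long aDiv = a / c;
        long bDiv = b / c;
        long aMod = a - aDiv * c; // equals a % c, computed without another division
        long bMod = b - bDiv * c; // equals b % c
        return a * bDiv + aDiv * bMod + Math.multiplyExact(aMod, bMod) / c;
    }

    public static void main(String[] args) {
        System.out.println(mulDiv(13, 2, 12)); // 2
        System.out.println(mulDiv(100, 7, 3)); // 233
    }
}
```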

tohoho