4

I recently implemented Karatsuba multiplication as a personal exercise. I wrote my implementation in Python following the pseudocode provided on Wikipedia:

procedure karatsuba(num1, num2)
  if (num1 < 10) or (num2 < 10)
    return num1*num2
  /* calculates the size of the numbers */
  m = max(size_base10(num1), size_base10(num2))
  m2 = m/2
  /* split the digit sequences about the middle */
  high1, low1 = split_at(num1, m2)
  high2, low2 = split_at(num2, m2)
  /* 3 calls made to numbers approximately half the size */
  z0 = karatsuba(low1, low2)
  z1 = karatsuba((low1+high1), (low2+high2))
  z2 = karatsuba(high1, high2)
  return (z2*10^(2*m2)) + ((z1-z2-z0)*10^(m2)) + (z0)

Here is my Python implementation:

def karat(x,y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        m = max(len(str(x)),len(str(y)))
        m2 = m / 2

        a = x / 10**(m2)
        b = x % 10**(m2)
        c = y / 10**(m2)
        d = y % 10**(m2)

        z0 = karat(b,d)
        z1 = karat((a+b),(c+d))
        z2 = karat(a,c)

        return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)

My question is about the final merge of z0, z1, and z2.
z2 is shifted m digits over (where m is the length of the largest of two multiplied numbers).
Instead of simply multiplying by 10^(m), the algorithm uses *10^(2*m2)* where m2 is m/2.

I tried replacing 2*m2 with m and got incorrect results. I think this has to do with how the numbers are split but I'm not really sure what's going on.

greybeard
Solomon Bothwell

8 Answers

13

Depending on your Python version, you must or should replace / with the explicit floor division operator //, which is the appropriate one here; it rounds down, ensuring that your exponents remain whole numbers.

This is essential, for example, when splitting your operands into high digits (by floor dividing by 10^m2) and low digits (by taking the remainder modulo 10^m2); this would not work with a fractional m2.

It also explains why 2 * (x // 2) does not necessarily equal x but rather x-1 if x is odd. In the last line of the algorithm, 2*m2 is correct because what you are doing is giving a and c their zeros back.
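
To see the odd case concretely, here is a small sketch with hypothetical operands (12345 and 678 are chosen only for illustration). It splits both numbers at m2 digits and compares the recombination that uses 2*m2 with the one that uses m:

```python
# Hypothetical operands chosen for illustration; the longer one has an odd length.
x, y = 12345, 678
m = max(len(str(x)), len(str(y)))   # 5
m2 = m // 2                          # 2, so 2*m2 is 4, not 5

# Split each operand into high and low parts at m2 digits.
a, b = divmod(x, 10**m2)             # 123, 45
c, d = divmod(y, 10**m2)             # 6, 78

# a and c each lost m2 digits, so their product must be shifted by 2*m2.
right = a*c * 10**(2*m2) + (a*d + b*c) * 10**m2 + b*d
wrong = a*c * 10**m + (a*d + b*c) * 10**m2 + b*d

print(right == x*y)   # True
print(wrong == x*y)   # False
```

Shifting by m instead of 2*m2 moves the high product one digit too far whenever m is odd.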

If you are on an older Python version (Python 2), your code may still work because / used to be interpreted as floor division when applied to integers.

def karat(x,y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        m = max(len(str(x)),len(str(y)))
        m2 = m // 2

        a = x // 10**(m2)
        b = x % 10**(m2)
        c = y // 10**(m2)
        d = y % 10**(m2)

        z0 = karat(b,d)
        z1 = karat((a+b),(c+d))
        z2 = karat(a,c)

        return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)
Paul Panzer
  • 1
    A little off the topic, but why do you need to put "else" after if? If condition is true it will return so there is no need for "else" – Pedram Mar 14 '18 at 17:12
  • @Pedram Matter of taste I'd say. It visually emphasises the logical symmetry and makes some patterns of modification - say you were to decide at some point you need to postprocess the results of all branches easier. Of course, one might argue that here it's not symmetric and the first branch is just a corner case to get over with. As I said, a matter of taste. – Paul Panzer Mar 14 '18 at 17:57
  • Is Karatsuba multiplication supposed to work with two negative numbers as input? If it is, your solution reaches max recursion depth. –  Oct 09 '19 at 19:59
  • 4
    @Saurabh Good catch! This probably doesn't work because Python's floor division rounds down (I think the C equivalent rounds towards zero). I suppose the simplest solution would be to strip the signs off in the beginning and put whatever is appropriate back in the end. – Paul Panzer Oct 09 '19 at 20:43
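
The sign-stripping idea from the last comment could be sketched as follows. The wrapper name karat_signed is hypothetical; karat is the algorithm from this answer, repeated in condensed form so the snippet runs on its own:

```python
def karat(x, y):
    # Same algorithm as in the answer above; assumes x >= 0 and y >= 0.
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x * y
    m2 = max(len(str(x)), len(str(y))) // 2
    a, b = divmod(x, 10**m2)
    c, d = divmod(y, 10**m2)
    z0 = karat(b, d)
    z1 = karat(a + b, c + d)
    z2 = karat(a, c)
    return z2 * 10**(2*m2) + (z1 - z2 - z0) * 10**m2 + z0

def karat_signed(x, y):
    # Hypothetical wrapper: record the sign, multiply the magnitudes, restore the sign.
    sign = -1 if (x < 0) != (y < 0) else 1
    return sign * karat(abs(x), abs(y))

print(karat_signed(-1234, 5678) == -1234 * 5678)
print(karat_signed(-1234, -5678) == 1234 * 5678)
```

Because the recursion only ever sees non-negative values, the len(str(x)) == 1 base case is reached as usual.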
2

I have implemented the same idea, but I have restricted the base case to 2-digit multiplication because I can reduce float multiplication in the function.

import math

def multiply(x,y):
    sx= str(x)
    sy= str(y)
    nx= len(sx)
    ny= len(sy)
    if ny<=2 or nx<=2:
        r = int(x)*int(y)
        return r
    n = nx
    if nx>ny:
        sy = sy.rjust(nx,"0")
        n=nx
    elif ny>nx:
        sx = sx.rjust(ny,"0")
        n=ny
    m = n%2
    offset = 0
    if m != 0:
        n+=1
        offset = 1
    floor = int(math.floor(n/2)) - offset
    a = sx[0:floor]
    b = sx[floor:n]
    c = sy[0:floor]
    d = sy[floor:n]
    print(a,b,c,d)

    ac = multiply(a,c)
    bd = multiply(b,d)

    ad_bc = multiply((int(a)+int(b)),(int(c)+int(d)))-ac-bd
    r = ((10**n)*ac)+((10**(n//2))*ad_bc)+bd

    return r

print(multiply(4,5))
print(multiply(4,58779))
print(int(multiply(4872139874092183,5977098709879)))
print(int(4872139874092183*5977098709879))
print(int(multiply(4872349085723098457,597340985723098475)))
print(int(4872349085723098457*597340985723098475))
print(int(multiply(4908347590823749,97098709870985)))
print(int(4908347590823749*97098709870985))
  • 2
    I don't understand `because I can reduce float multiplication in function`: what goal are you trying to achieve in the first place using floating point arithmetic? – greybeard Jan 02 '18 at 17:46
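
As a rough illustration of why floating point matters here: in Python 3, an exponent written with / is a float, so the power and everything multiplied by it become floats too. The value named coefficient below is a hypothetical 20-digit number used only for this sketch:

```python
n = 40
coefficient = 12345678901234567891   # hypothetical 20-digit value

print(type(10 ** (n / 2)))    # float in Python 3
print(type(10 ** (n // 2)))   # int

exact = 10 ** (n // 2) * coefficient
approx = int(10 ** (n / 2) * coefficient)

# The float product cannot hold all 40 digits, so the two results differ.
print(exact == approx)        # False
```

Keeping every exponent an integer with // keeps the whole computation in exact integer arithmetic.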
1

I tried replacing 2*m2 with m and got incorrect results. I think this has to do with how the numbers are split but I'm not really sure what's going on.

This goes to the heart of how you split your numbers for the recursive calls. If n is odd, n//2 rounds down, so the low part of each number has floor(n/2) digits, and the high part must be padded with floor(n/2) zeros to restore its place value. Since the same n is used for both numbers, this applies to both. The product of the two high parts therefore needs floor(n/2)*2 zeros, the sum of the two paddings. If you stick with the original odd n in the final step, you pad that term with n zeros, which is one too many.
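
Written out as a short derivation (using the question's names a, b, c, d, z0, z1, z2, with m_2 standing for floor(n/2)), the padding argument reads:

```latex
\begin{align*}
x &= a \cdot 10^{m_2} + b, \qquad y = c \cdot 10^{m_2} + d, \qquad m_2 = \lfloor n/2 \rfloor \\
xy &= ac \cdot 10^{2 m_2} + (ad + bc) \cdot 10^{m_2} + bd \\
   &= z_2 \cdot 10^{2 m_2} + (z_1 - z_2 - z_0) \cdot 10^{m_2} + z_0,
\end{align*}
\text{where } z_0 = bd, \quad z_1 = (a+b)(c+d), \quad z_2 = ac .
```

The exponent on the high product is 2*m_2 because each factor contributes m_2 zeros; it equals n only when n is even.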

karuhanga
1

You have used m2 as a float. It needs to be an integer.

def karat(x,y):
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        m = max(len(str(x)),len(str(y)))
        m2 = m // 2

        a = x // 10**(m2)
        b = x % 10**(m2)
        c = y // 10**(m2)
        d = y % 10**(m2)

        z0 = karat(b,d)
        z1 = karat((a+b),(c+d))
        z2 = karat(a,c)

        return (z2 * 10**(2*m2)) + ((z1 - z2 - z0) * 10**(m2)) + (z0)

  • 1
    Please don't post only code as answer, but also provide an explanation what your code does and how it solves the problem of the question. Answers with an explanation are usually more helpful and of better quality, and are more likely to attract upvotes. – Dima Kozhevin Jul 30 '20 at 13:35
0

Your code and logic are correct; there is just an issue with your base case. Since, according to the algorithm, a, b, c, and d are 2-digit numbers, you should modify your base case to use a length of 2 for x and y.

  • Can you elaborate? This answer isn't entirely clear. – CertainPerformance Jul 28 '18 at 21:44
  • In your code x and y are 4 digit number and you know you have divided x as (10^(n/2))a+b that is if x=5678, a=56 and b=78 that is x=((10^2)a)+b this makes the length of string in our base case= 2. And also change your return statement as follows .. ((10**(2*m2))*z2)+((10**m2)*(z1-z2-z0))+z0 Nothing wrong with your return expression just some missing brackets which I think can cause a change is calculation. I hope you have understood what is wrong with the base case................. – Naira Rahim Jul 30 '18 at 19:28
  • Don't post essential information in a comment, put it in your answer instead. Comments are transient; answers are not. – CertainPerformance Jul 30 '18 at 19:49
0

I think it would be better to use the math.log10 function to calculate the number of digits instead of converting to a string, something like this:

import math

def number_of_digits(number):
  """
  Used log10 to find no. of digits
  """
  if number > 0:
    return int(math.log10(number)) + 1
  elif number == 0:
    return 1
  else:
    return int(math.log10(-number)) + 1 # Don't count the '-'
Mu-Majid
  • I agree, although the base case makes it so that we'll never see n==0. From there, we might as well just use `abs` to take the absolute value anyway. So the body could just be: `return int(math.log10(abs(number)))+1` – Austin Cory Bart Mar 05 '23 at 02:35
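
One caveat that may be worth checking before relying on this: math.log10 returns a float, so for integers just below a power of ten the rounded logarithm can land on the next whole number, and the count may come out one too high. The sketch below prints both counts side by side; the helper names digits_log10 and digits_str are hypothetical:

```python
import math

def digits_log10(number):
    # The log10-based count from the answer, with abs() as suggested in the comment.
    # Not defined for number == 0.
    return int(math.log10(abs(number))) + 1

def digits_str(number):
    # String-based count; exact for any integer, including 0.
    return len(str(abs(number)))

# Compare the two counts; the last value sits just below a power of ten.
for n in (7, 12345, 10**15 - 1):
    print(n, digits_str(n), digits_log10(n))
```

If the two columns ever disagree on your platform, the string-based count is the exact one.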
0

I wanted a version that could fit nicely on a single slide, without sacrificing clarity. Here's what I came up with:

import math

def number_of_digits(number: int) -> int:
    return int(math.log10(abs(number))) + 1

def multiply(x: int, y: int) -> int:
    # We CAN multiply small numbers
    if abs(x) < 10 or abs(y) < 10:
        return x * y

    # Calculate the size of the numbers
    digits = max(number_of_digits(x), number_of_digits(y))
    midpoint = 10 ** (digits // 2)

    # Split digit sequences in the middle
    high_x = x // midpoint
    low_x = x % midpoint
    high_y = y // midpoint
    low_y = y % midpoint

    # 3 recursive calls to numbers approximately half the size
    z0 = multiply(low_x, low_y)
    z1 = multiply(low_x + high_x, low_y + high_y)
    z2 = multiply(high_x, high_y)

    return (z2 * midpoint**2) + ((z1 - z2 - z0) * midpoint) + (z0)

print(multiply(2**100, 3**100))

I'd argue that:

  1. the variable names are clearer
  2. the number_of_digits helper function should be using log10 to find the number of digits, instead of str+len
  3. The math is a little clearer by extracting out the 10**(digits//2) term.
Austin Cory Bart
-1

The base case if len(str(x)) == 1 or len(str(y)) == 1: return x*y is incorrect. If you run either of the Python implementations given in the answers against large integers, the karat() function will not produce the correct answer.

To make the code correct, you need to change the base case to if len(str(x)) < 3 or len(str(y)) < 3: return x*y.

Below is a modified implementation of Paul Panzer's answer that correctly multiplies large integers.

def karat(x,y):
    if len(str(x)) < 3 or len(str(y)) < 3:
        return x*y

    n = max(len(str(x)),len(str(y))) // 2

    a = x // 10**(n)
    b = x % 10**(n)
    c = y // 10**(n)
    d = y % 10**(n)

    z0 = karat(b,d)
    z1 = karat((a+b), (c+d))
    z2 = karat(a,c)

    return ((10**(2*n))*z2)+((10**n)*(z1-z2-z0))+z0
micarlise
  • 1
    I can't really see an improvement in your code. You just replaced the `==1`-conditions in the if-clause and as I can see, nothing else changed. Have I overseen something really important? Or is the change from `==1` to `<3` the major part. Then, some explanations about this would be fine. – colidyre Aug 19 '18 at 13:03
  • @colidyre `==1` is not the correct recursion base case. Therefore, the previous answer's code examples does not actually give the correct answer for multiplying it's input numbers. I made some edits to my answer to clarify. – micarlise Aug 19 '18 at 23:07
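
One way to test claims about the base case is to compare against Python's built-in multiplication on random inputs. The sketch below is only a check under stated assumptions: it uses floor division throughout and non-negative inputs, and base_digits is a hypothetical parameter, where 1 corresponds to the == 1 base case and 2 to the < 3 base case:

```python
import random

def karat(x, y, base_digits=1):
    # base_digits is a hypothetical knob: 1 matches the `== 1` base case,
    # 2 matches the `< 3` base case proposed in this answer.
    if len(str(x)) <= base_digits or len(str(y)) <= base_digits:
        return x * y
    n = max(len(str(x)), len(str(y))) // 2
    a, b = divmod(x, 10**n)
    c, d = divmod(y, 10**n)
    z0 = karat(b, d, base_digits)
    z1 = karat(a + b, c + d, base_digits)
    z2 = karat(a, c, base_digits)
    return (10**(2*n)) * z2 + (10**n) * (z1 - z2 - z0) + z0

def mismatches(base_digits, trials=100, seed=0):
    # Count how often karat disagrees with built-in multiplication.
    rng = random.Random(seed)
    bad = 0
    for _ in range(trials):
        x = rng.randrange(10**40)
        y = rng.randrange(10**40)
        if karat(x, y, base_digits) != x * y:
            bad += 1
    return bad

print(mismatches(1), mismatches(2))
```

In this sketch both base cases agree with x * y on the sampled inputs, so it may be worth re-running the earlier answers with // in place of / before concluding that the == 1 base case is at fault.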