There are several issues:
The exception you got is caused by infinite recursion at this call:

    kara_mul(a + b, c + d)

As these variables are strings, the `+` is a string concatenation. This means these arguments evaluate to `n` and `m`, which were the arguments to the current execution of the function.
The correct algorithm would perform a numerical addition here, for which you need to provide an implementation (adding two string representations of potentially very long integers).
`if (len_n == 1 && len_m == 1)` detects the base case, but the base case should kick in when either of these sizes is 1, not necessarily both. So this should be an `||` operator, or should be written as two separate `if` statements.
The input strings should be split such that `b` and `d` are equal in size. This is not what your code does. Note how the Wikipedia article stresses this point:

> The second argument of the split_at function specifies the number of digits to extract from the right
`stol` should never be called on strings that could potentially be too long for conversion to `long`. For example, `stol(p1)` is not safe, as `p1` could have 20 or more digits.
As a consequence of the previous point, you'll need to implement functions that add or subtract two string representations of numbers, and also one that can multiply a string representation by a single digit (the base case).
Here is an implementation that corrects these issues:
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
/// Returns the decimal digit at position `i` of the numeral `n`, counting
/// from the right (i == 0 is the least significant digit).
/// Positions at or beyond the numeral's length read as 0, so numerals of
/// different lengths can be combined digit by digit.
/// @param n numeral made of '0'..'9' characters (not validated)
/// @param i zero-based position from the right
/// @return the digit value, or 0 when `i` is out of range
int digit(const std::string& n, int i) {
    // Taking `n` by const reference matters: add/subtract call this once per
    // digit, and a by-value copy made each of those loops O(len^2).
    // A negative `i` converts to a huge size_t and therefore also reads as 0.
    const auto pos = static_cast<std::size_t>(i);
    return pos >= n.size() ? 0 : n[n.size() - pos - 1] - '0';
}
/// Adds two non-negative decimal numerals given as strings.
/// The shorter operand is treated as if zero-padded on the left, and leading
/// zeros of the inputs are kept up to the longer input's length
/// (e.g. add("05", "3") == "08").
/// @param n first addend, digits '0'..'9'
/// @param m second addend, digits '0'..'9'
/// @return the sum as a decimal numeral
std::string add(const std::string& n, const std::string& m) {
    const int len = static_cast<int>(std::max(n.size(), m.size()));
    std::string result;
    result.reserve(static_cast<std::size_t>(len) + 1);
    int carry = 0;
    // Build the sum least-significant digit first, then reverse once at the end.
    for (int i = 0; i < len; ++i) {
        const int sum = digit(n, i) + digit(m, i) + carry;
        result += static_cast<char>(sum % 10 + '0');
        carry = sum >= 10;
    }
    if (carry) result += '1';
    std::reverse(result.begin(), result.end());
    return result;
}
/// Computes n - m for non-negative decimal numerals given as strings.
/// The result has no leading zeros, and is "0" when the difference is zero
/// (also when the inputs differ only in leading zeros, e.g. "05" - "5").
/// @param n minuend, digits '0'..'9'
/// @param m subtrahend, digits '0'..'9'; may carry leading zeros
/// @return the difference as a decimal numeral
/// @throws std::invalid_argument("subtraction overflow") if m > n
std::string subtract(const std::string& n, const std::string& m) {
    if (n == m) return "0";
    // Walk the longer length so that an `m` with leading zeros (e.g. "05")
    // is judged by value, not by length: only a final borrow means m > n.
    const int len = static_cast<int>(std::max(n.size(), m.size()));
    std::string result;
    result.reserve(static_cast<std::size_t>(len));
    int borrow = 0;
    for (int i = 0; i < len; ++i) {
        const int diff = digit(n, i) - digit(m, i) - borrow;
        borrow = diff < 0;
        result += static_cast<char>(diff + borrow * 10 + '0');
    }
    if (borrow) throw std::invalid_argument("subtraction overflow");
    // `result` is least-significant first, so its trailing zeros are the
    // number's leading zeros. If only zeros were present, the value is 0.
    result.erase(result.find_last_not_of('0') + 1);
    if (result.empty()) return "0";
    std::reverse(result.begin(), result.end());
    return result;
}
/// Multiplies the decimal numeral `n` by a small integer `coefficient`
/// (normally a single digit, as used by the Karatsuba base case).
/// Uses binary double-and-add: the bits of `coefficient` are read from least
/// to most significant while `n` is repeatedly doubled, and every double
/// whose bit is set is added into the running total.
/// @return "0" when coefficient is 0; `n` itself for any other coefficient
///         below 2 (including negative values)
std::string simple_mul(std::string n, int coefficient) {
    if (coefficient < 2) {
        return coefficient != 0 ? n : "0";
    }
    std::string power = n;  // n * 2^k for the bit currently being examined
    std::string total;      // running sum; stays empty until the first set bit
    for (int bits = coefficient; bits > 0; bits /= 2) {
        if (bits % 2 != 0) {
            total = total.empty() ? power : add(total, power);
        }
        // No need to double again once the top bit has been consumed.
        if (bits > 1) {
            power = add(power, power);
        }
    }
    return total;
}
/// Multiplies two non-negative decimal numerals with Karatsuba's algorithm.
/// Both inputs are split at the same distance k from the right, so the low
/// parts b and d have equal length:
///   n = a*10^k + b,   m = c*10^k + d
///   n*m = p1*10^(2k) + p3*10^k + p2
/// with p1 = a*c, p2 = b*d, p3 = (a+b)*(c+d) - p1 - p2.
/// An empty operand is treated as 0 (consistent with `digit`).
/// @param n first factor, digits '0'..'9'
/// @param m second factor, digits '0'..'9'
/// @return the product as a decimal numeral
std::string kara_mul(const std::string& n, const std::string& m) {
    // Without this guard an empty operand yields k == 0, so a == n and c == m,
    // and the first recursive call would repeat forever.
    if (n.empty() || m.empty()) return "0";
    const int len_n = static_cast<int>(n.size());
    const int len_m = static_cast<int>(m.size());
    // Base case: a single-digit operand only needs repeated addition.
    if (len_n == 1) return simple_mul(m, digit(n, 0));
    if (len_m == 1) return simple_mul(n, digit(m, 0));
    const int k = std::min(len_n, len_m) / 2;  // digits in each low part (>= 1)
    const std::string a = n.substr(0, len_n - k);
    const std::string b = n.substr(len_n - k);
    const std::string c = m.substr(0, len_m - k);
    const std::string d = m.substr(len_m - k);
    const std::string p1 = kara_mul(a, c);
    const std::string p2 = kara_mul(b, d);
    // Numeric (not string) additions here; concatenating would recurse on n, m.
    const std::string p3 = subtract(kara_mul(add(a, b), add(c, d)), add(p1, p2));
    // Appending k zeros multiplies by 10^k.
    return add(add(p1 + std::string(2 * k, '0'), p2), p3 + std::string(k, '0'));
}