The key insight is that we can work backwards, finding s and t such that gcd(a, b) = s * a + t * b for each a and b in the recursion. So say we have a = 21 and b = 15. We need to work through each iteration, using several values: a, b, a % b, and c, where a = c * b + a % b (so c is the quotient a // b). First, let's consider each step of the basic GCD algorithm:
21 = 1 * 15 + 6
15 = 2 * 6 + 3
6 = 2 * 3 + 0 -> end recursion
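The division steps above can be reproduced with a short loop. This is a minimal sketch; the name euclid_steps is hypothetical and not part of the original code. It records each quotient c and remainder a % b, then prints them in the same a = c * b + r form. The last nonzero divisor, 3, is the gcd.

```python
def euclid_steps(a, b):
    """Return the (a, c, b, r) tuples of Euclid's algorithm, where a == c * b + r."""
    steps = []
    while b != 0:
        c, r = divmod(a, b)  # quotient c and remainder r == a % b
        steps.append((a, c, b, r))
        a, b = b, r  # the next iteration works on (b, a % b)
    return steps


for a, c, b, r in euclid_steps(21, 15):
    print(f"{a} = {c} * {b} + {r}")
```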
So our gcd (g) is 3. Once we have that, we work back up to determine s and t for 6 and 3. To do so, we begin with g at the bottom of the recursion, where a = 3 and b = 0, and express it with s = 1 and t = -1 (because b = 0, any value of t would do):
3 = 1 * 3 + -1 * 0
Now we want to eliminate the 0 term. From the last line of the basic algorithm, we know that 0 = 6 - 2 * 3:
3 = 1 * 3 + -1 * (6 - 2 * 3)
Simplifying, we get
3 = 1 * 3 + -1 * 6 + 2 * 3
3 = 3 * 3 + -1 * 6
Now we swap the terms:
3 = -1 * 6 + 3 * 3
So we have s = -1 and t = 3 for a = 6 and b = 3. Given those values of a and b, gcd2 should return (3, -1, 3).
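Each substitution in this walkthrough follows the same pattern. If g = s * b + t * (a % b), and a % b = a - (a // b) * b, then collecting terms gives g = t * a + (s - (a // b) * t) * b. Below is a minimal sketch of that single step, using a hypothetical helper named step_up (not from the original), which reproduces the values just derived from the bottom of the recursion:

```python
def step_up(a, b, s, t):
    """Given g == s * b + t * (a % b), return (new_s, new_t) with g == new_s * a + new_t * b."""
    q = a // b  # the quotient c in a == c * b + a % b
    # Substitute a % b == a - q * b, collect the a and b terms, then swap them.
    return t, s - q * t


# Bottom of the recursion: 3 == 1 * 3 + -1 * 0, i.e. s, t = 1, -1 for (3, 0).
s, t = step_up(6, 3, 1, -1)
print(s, t)  # -1 3
```

Applying the same helper again with (15, 6) and then (21, 15) carries the coefficients up through the remaining levels.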
Now we step back up through the recursion, and we want to eliminate the trailing factor of 3 (the b from the previous level, not the gcd on the left). From the next-to-last line of the basic algorithm, we know that 3 = 15 - 2 * 6. Substituting, simplifying, and swapping again (slowly, so that you can see the steps clearly):
3 = -1 * 6 + 3 * (15 - 2 * 6)
3 = -1 * 6 + 3 * 15 - 6 * 6
3 = -7 * 6 + 3 * 15
3 = 3 * 15 + -7 * 6
So for this level of recursion, with a = 15 and b = 6, we return (3, 3, -7). Now we want to eliminate the 6 term. From the first line of the basic algorithm, we know that 6 = 21 - 1 * 15:
3 = 3 * 15 + -7 * (21 - 1 * 15)
3 = 3 * 15 + 7 * 15 - 7 * 21
3 = 10 * 15 - 7 * 21
3 = -7 * 21 + 10 * 15
And voila, we have calculated s = -7 and t = 10 for 21 and 15, so gcd2(21, 15) should return (3, -7, 10).
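As a quick sanity check, each (s, t) pair found above can be plugged back into s * a + t * b and compared against Python's built-in math.gcd. The list of tuples is simply the walkthrough's results transcribed by hand, from the bottom of the recursion up:

```python
import math

# (a, b, s, t) at each level of the recursion, from the bottom up.
levels = [
    (3, 0, 1, -1),
    (6, 3, -1, 3),
    (15, 6, 3, -7),
    (21, 15, -7, 10),
]

for a, b, s, t in levels:
    g = s * a + t * b
    assert g == math.gcd(a, b) == 3  # every level reproduces the same gcd
    print(f"{g} = {s} * {a} + {t} * {b}")
```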
So schematically, the recursive function will look like this:
```python
def gcd2(a, b):
    if a % b == 0:
        # calculate s and t
        return b, s, t
    else:
        g, s, t = gcd2(b, a % b)
        # calculate new_s and new_t
        return g, new_s, new_t
```
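One way to fill in that schematic is sketched below; it is an illustration under stated assumptions, not necessarily how the original author completes it. In the base case b divides a, so g = b = 0 * a + 1 * b. In the recursive case, the substitution used in the walkthrough gives new_s = t and new_t = s - (a // b) * t. This version assumes b != 0 (a % 0 raises ZeroDivisionError).

```python
def gcd2(a, b):
    """Return (g, s, t) with g == gcd(a, b) == s * a + t * b; assumes b != 0."""
    if a % b == 0:
        # b divides a, so g = b = 0 * a + 1 * b.
        return b, 0, 1
    g, s, t = gcd2(b, a % b)
    # g == s * b + t * (a % b) and a % b == a - (a // b) * b,
    # so g == t * a + (s - (a // b) * t) * b.
    return g, t, s - (a // b) * t


print(gcd2(21, 15))  # (3, -2, 3)
```

Because s and t are not unique, this version stops one level earlier (at 6 and 3, with s = 0 and t = 1) and returns (3, -2, 3) for (21, 15) rather than the walkthrough's (3, -7, 10). Both satisfy s * 21 + t * 15 == 3.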
Note that for our purposes here, a slightly different base case simplifies things: stopping when b reaches 0 mirrors the walkthrough above, which starts from a = 3 and b = 0.
```python
def gcd2(a, b):
    if b == 0:
        # g = a = 1 * a + t * 0, so t is arbitrary here; -1 matches the walkthrough
        return a, 1, -1
    else:
        g, s, t = gcd2(b, a % b)
        # calculate new_s and new_t
        return g, new_s, new_t
```
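Filling in the placeholder with the same back-substitution used by hand (new_s = t, new_t = s - (a // b) * t) gives a version whose intermediate results match the walkthrough exactly: (3, -1, 3) for (6, 3), (3, 3, -7) for (15, 6), and (3, -7, 10) for (21, 15). This is a sketch under that assumption, not necessarily the author's final code.

```python
def gcd2(a, b):
    """Return (g, s, t) with g == gcd(a, b) == s * a + t * b."""
    if b == 0:
        # g = a = 1 * a + t * 0 for any t; -1 matches the walkthrough.
        return a, 1, -1
    g, s, t = gcd2(b, a % b)
    # Back-substitute a % b == a - (a // b) * b, as in the walkthrough.
    return g, t, s - (a // b) * t


print(gcd2(21, 15))  # (3, -7, 10)
```

Because the recursion bottoms out at b == 0, this version also accepts a call such as gcd2(5, 0) directly, returning (5, 1, -1).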