TL;DR: build a carry-lookahead adder where each individual adder adds modulo radix instead of modulo 2.
Additions need incoming carries
The problem in your model is that you have a rippling carry. See ripple-carry adders.
If you were on an FPGA that wouldn't be a problem, because they have dedicated logic to do that fast (carry chains, they're cool). But alas, you're on a GPU!
That is, for a given id, you only know the input carry (and thus whether you are going to sum A[id]+B[id] or A[id]+B[id]+1) once all the sums with smaller id values have been computed. As a matter of fact, initially, you only know the first carry.
A[3]+B[3] + ?   A[2]+B[2] + ?   A[1]+B[1] + ?   A[0]+B[0] + 0
      |               |               |               |
      v               v               v               v
    C[3]            C[2]            C[1]            C[0]
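For reference, the rippling behaviour in the drawing can be sketched as a plain sequential loop. This is only a baseline to compare against, assuming digits are stored least significant first (index 0 is the rightmost column of the drawing); the function name ripple_add is made up for the illustration.

```python
def ripple_add(A, B, radix):
    """Sequential reference adder; digit 0 is the least significant one."""
    C = []
    carry = 0  # the only carry known up front
    for a, b in zip(A, B):
        s = a + b + carry
        C.append(s % radix)
        carry = s // radix  # needed before the next digit can be finished
    return C, carry
```

Every iteration depends on the carry of the previous one, which is exactly the O(n) chain the rest of this answer tries to avoid.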
Characterize the carry output
And each sum also has a carry output, which isn't on the drawing. So you have to think of each addition in this larger scheme as a function with 3 inputs and 2 outputs: (C, c_out) = add(A, B, c_in)
In order not to wait O(n) for the sum to complete (where n is the number of items your sum is cut into), you can precompute all the possible results at each id. That isn't such a huge load of work, since A and B don't change, only the carries. So you have 2 possible outputs: (C, c_out0) = add(A, B, 0) and (C', c_out1) = add(A, B, 1).
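A minimal sketch of that 3-input, 2-output function and of the precomputation, assuming a single digit per call; precompute is a name invented here, not something from your code.

```python
def add(a, b, c_in, radix):
    """One digit addition: returns (C, c_out)."""
    s = a + b + c_in
    return s % radix, s // radix


def precompute(a, b, radix):
    """Both possible outcomes, for c_in = 0 and for c_in = 1."""
    return add(a, b, 0, radix), add(a, b, 1, radix)
```

For example, precompute(4, 5, 10) gives ((9, 0), (0, 1)): the carry output depends on the carry input, which is the propagate case discussed next.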
Now with all these results, we basically need to implement a carry-lookahead unit.
For that, we need to figure out two functions of each sum's carry output, P and G:
P, a.k.a. all of the following definitions:
- Propagate
- "if a carry comes in, then a carry will go out of this sum"
- c_out1 && !c_out0
- A + B == radix-1
G, a.k.a. all of the following definitions:
- Generate
- "whatever carry comes in, a carry will go out of this sum"
- c_out1 && c_out0
- c_out0
- A + B >= radix
So in other terms, c_out = G or (P and c_in). So now we have the start of an algorithm that can easily tell us, for each id, the carry output directly as a function of its carry input:
1. At each id, compute C[id] = A[id]+B[id]+0
2. Get G[id] = C[id] > radix-1
3. Get P[id] = C[id] == radix-1
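These first steps can be sketched as follows. It is a plain sequential illustration (on the GPU each id would do its own element), and the function names are made up for the example.

```python
def propagate_generate(A, B, radix):
    """Per-digit sums with c_in = 0, plus the P and G flags."""
    C = [a + b for a, b in zip(A, B)]
    G = [c > radix - 1 for c in C]   # a carry goes out whatever comes in
    P = [c == radix - 1 for c in C]  # a carry goes out only if one comes in
    return C, P, G


def carry_out(p, g, c_in):
    """c_out = G or (P and c_in)"""
    return g or (p and c_in)
```

Note that P and G can never both be true for the same digit, since the sum cannot be equal to radix-1 and greater than it at the same time.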
Logarithmic tree
Now we can finish in O(log(n)). Treeish things are nasty on GPUs, but this is still shorter than waiting. Indeed, from 2 additions next to each other, we can get a group G and a group P.
For id and id+1:
4. step = 2
5. if id % step == 0, do steps 6 through 10, otherwise, do nothing
6. group_P = P[id] and P[id+step/2]
7. group_G = (P[id+step/2] and G[id]) or G[id+step/2]
8. c_in[id+step/2] = G[id] or (P[id] and c_in[id])
9. step = step * 2
10. if step <= n, go to 5
At the end (after repeating steps 5-10 for every level of your tree, with fewer ids every time, and with P and G at each level standing for the group_P and group_G of the level below), everything will be expressed in terms of the Ps and Gs which you computed, and c_in[0], which is 0. On the Wikipedia page there are formulas for grouping by 4 instead of 2, which will get you an answer in log_4(n) levels instead of log_2(n).
Hence the end of the algorithm:
- At each id, get c_in[id]
- return (C[id]+c_in[id]) % radix
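Here is one way the tree and the ending could be put together, as a sequential sketch rather than GPU code. It makes a few assumptions that are not in the steps above: the group values of each level are kept in separate lists (instead of being indexed by id and step), the input is padded to a power of two with digits that neither propagate nor generate, and the carries are resolved in a second pass from the top of the tree back down. The name lookahead_add is invented for the example.

```python
def lookahead_add(A, B, radix):
    """Digit-wise addition (least significant digit first), carry-out dropped."""
    n = len(A)
    C = [a + b for a, b in zip(A, B)]
    size = 1
    while size < n:
        size *= 2
    pad = [False] * (size - n)
    P = [c == radix - 1 for c in C] + pad
    G = [c > radix - 1 for c in C] + pad

    # Going up: levels[k] holds (P, G) for groups of 2**k digits.
    levels = [(P, G)]
    while len(levels[-1][0]) > 1:
        p, g = levels[-1]
        half = len(p) // 2
        group_P = [p[2 * i] and p[2 * i + 1] for i in range(half)]
        group_G = [(p[2 * i + 1] and g[2 * i]) or g[2 * i + 1] for i in range(half)]
        levels.append((group_P, group_G))

    # Going down: the whole number has c_in = 0, then each group hands
    # its own carry-in to its low half and a computed one to its high half.
    c_in = [False]
    for p, g in reversed(levels[:-1]):
        nxt = []
        for i, c in enumerate(c_in):
            nxt.append(c)
            nxt.append(g[2 * i] or (p[2 * i] and c))
        c_in = nxt

    return [(C[i] + c_in[i]) % radix for i in range(n)]
```

Each list comprehension and each pass of the inner loop touches independent elements, so that is the part a GPU would run in parallel; only the number of levels, log_2 of the padded size, is sequential.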
Take advantage of hardware
What we really did in this last part was mimic the circuitry of a carry-lookahead adder with logic. However, we already have adders in the hardware that do similar things (by definition).
Let us replace our definitions of P and G based on radix by those based on 2, like the logic inside our hardware, mimicking a sum of 2 bits a and b at each stage: P = a ^ b (xor) and G = a & b (and). In other words, a = P or G and b = G. So if we create an intP integer and an intG integer, where each bit is respectively the P and G we computed from each id's sum (limiting us to 64 sums), then the addition (intP | intG) + intG has the exact same carry propagation as our elaborate logical scheme.
The reduction to form these integers will still be a logarithmic operation I guess, but that was to be expected.
The interesting part is that each bit of the sum is a function of its carry input. Indeed, every bit of the sum is ultimately a function of 3 bits: (a+b+c_in) % 2.
- If at that bit P == 1, then a + b == 1, thus (a+b+c_in) % 2 == !c_in
- Otherwise, a+b is either 0 or 2, and (a+b+c_in) % 2 == c_in
Thus we can trivially form the integer (or rather bit-array) int_cin = ((intP|intG)+intG) ^ intP, with ^ being xor.
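As a quick sanity check of that formula, here is a small sketch. Python integers don't wrap around, so the MASK64 constant is an assumption added here to emulate a 64-bit unsigned register; carry_bits is a made-up name.

```python
MASK64 = (1 << 64) - 1  # emulates 64-bit unsigned wrap-around


def carry_bits(intP, intG):
    """Bit i of the result is the carry coming into position i."""
    a = intP | intG
    b = intG
    return ((a + b) & MASK64) ^ intP
```

With P = 0110 and G = 0001 in binary (position 0 generates, positions 1 and 2 propagate), carry_bits(6, 1) gives 14, i.e. 1110: a carry comes into positions 1, 2 and 3 and none into position 0.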
Thus we have an alternate ending to our algorithm, replacing steps 4 and later:
4. At each id, shift P and G by id: P = P << id and G = G << id
5. Do an OR-reduction to get intG and intP, which are the OR of all the P and G for id 0..63
6. Compute (once) int_cin = ((intP|intG)+intG) ^ intP
7. At each id, get c_in = (int_cin & (1ULL << id)) ? 1 : 0;
8. Return (C[id]+c_in) % radix
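Put together, the alternate ending could look like the sketch below. It is sequential Python standing in for the GPU version: the loop that builds intP and intG plays the role of the OR-reduction, MASK64 is an assumption to emulate 64-bit wrap-around, and packed_add is a name invented here.

```python
MASK64 = (1 << 64) - 1  # emulates 64-bit unsigned wrap-around


def packed_add(A, B, radix):
    """Up to 64 digits, least significant first, carry-out dropped."""
    assert len(A) == len(B) <= 64
    C = [a + b for a, b in zip(A, B)]
    intP = 0
    intG = 0
    for i, c in enumerate(C):  # stands in for the shift + OR-reduction
        intP |= int(c == radix - 1) << i
        intG |= int(c > radix - 1) << i
    # One hardware addition resolves every carry at once.
    int_cin = (((intP | intG) + intG) & MASK64) ^ intP
    return [(c + ((int_cin >> i) & 1)) % radix for i, c in enumerate(C)]
```

For instance, packed_add([9, 9, 9, 1], [1, 0, 0, 2], 10) returns [0, 0, 0, 4], which is 1999 + 2001 = 4000 written least significant digit first.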
PS: Also, watch out for integer overflow in your arrays if radix is big. If it isn't, then the whole thing doesn't really make sense I guess...
PPS: in the alternate ending, if you have more than 64 items, characterize each group of 64 by its P and G as if radix was 2^64, re-run the same steps at a higher level (reduction, get c_in), and then get back to the lower level and apply step 7 with P+G+carry in from the higher level.
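My reading of this two-level idea, as a sketch only: each block of 64 digits gets one P bit and one G bit from the sum (intP|intG)+intG treated as a digit in radix 2^64, the same trick resolves the carries between blocks, and each block then redoes its addition with its incoming carry added. The sketch assumes at most 64 blocks (4096 digits). It also leans on Python's unbounded integers to see the overflow, which on real hardware you would have to detect yourself. All names are invented for the example.

```python
MASK64 = (1 << 64) - 1  # emulates 64-bit unsigned wrap-around


def big_packed_add(A, B, radix):
    """Up to 64*64 digits, least significant first, carry-out dropped."""
    C = [a + b for a, b in zip(A, B)]
    blocks = []
    for start in range(0, len(C), 64):
        intP = intG = 0
        for i, c in enumerate(C[start:start + 64]):
            intP |= int(c == radix - 1) << i
            intG |= int(c > radix - 1) << i
        blocks.append((intP, intG))
    assert len(blocks) <= 64

    # Higher level: one P bit and one G bit per block, as if radix was 2**64.
    hiP = hiG = 0
    for k, (intP, intG) in enumerate(blocks):
        s = (intP | intG) + intG  # not masked: overflow must stay visible
        hiP |= int(s == MASK64) << k
        hiG |= int(s > MASK64) << k
    hi_cin = (((hiP | hiG) + hiG) & MASK64) ^ hiP

    # Lower level again, now with the carry coming into each block.
    out = []
    for k, (intP, intG) in enumerate(blocks):
        block_c_in = (hi_cin >> k) & 1
        int_cin = (((intP | intG) + intG + block_c_in) & MASK64) ^ intP
        for i, c in enumerate(C[64 * k:64 * k + 64]):
            out.append((c + ((int_cin >> i) & 1)) % radix)
    return out
```

A partial last block is fine here: its unused upper bits are zero, so it can never report propagate or generate, and nothing sits above it anyway.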