In the code that follows I implement addition and subtraction using the same idea as in the question. The only practical difference is that in my implementation these two functions also take in a carry-in/borrow-in bit and produce a carry-out/borrow-out bit.
The carry-in bit is used to implement subtraction via addition and this bit helps to get correct values of the carry-out and borrow-out bits. Basically, I implement typical CPU-like addition and subtraction with the carry flag in the status register.
The carry/borrow bits are then used to implement comparison via subtraction. I implement comparison without the `>=` operator, which I also consider arithmetic because it is not really a bitwise operation. The comparison function is needed by the division function because it uses the restoring division algorithm.
I also avoid using the `!` operator and use `^1` instead.
The division function takes the dividend as two `unsigned int`s: its most- and least-significant halves. At the end it replaces the most-significant half with the remainder and the least-significant half with the quotient. So it does both division and modulo, in a typical CPU-like way (e.g. like the x86 `DIV` instruction). The function returns 1 on success and 0 on overflow/division by 0.
The main function does a simple test. It compares the results from the division function against the results of direct division and terminates with an error message on a mismatch.
I use `unsigned long long` in the test part so that a loop bounded by divisor=`UINT_MAX` terminates: an `unsigned` counter would wrap around to 0 and never exceed `UINT_MAX`, producing an infinite loop. Testing the entire range of values of the dividend and the divisor would take too much time, which is why I cap them at 0xFFFF and 0xFF respectively instead of at `UINT_MAX`.
Code:
#include <stdio.h>
#include <limits.h>
/*
 * Adds a, b and the carry-in bit using only bitwise operations, like a CPU
 * adder with a carry flag.
 *
 * carryIn  - incoming carry bit (0 or 1); it enters at bit 0.
 * carryOut - set to 1 if the sum does not fit in an unsigned, else 0.
 * Returns the low sizeof(unsigned) * CHAR_BIT bits of a + b + carryIn.
 */
unsigned add(unsigned a, unsigned b, unsigned carryIn, unsigned* carryOut)
{
    const unsigned topBit = UINT_MAX / 2 + 1; // +-*/ are OK in constants
    const unsigned topShift = sizeof(unsigned) * CHAR_BIT - 1;
    unsigned partial = a ^ b ^ carryIn;                 /* per-bit sum, ignoring carries */
    unsigned carries = (a & b) | ((a | b) & carryIn);   /* carry produced by each bit */
    unsigned overflow = 0;

    for (;;)
    {
        unsigned shifted = carries << 1;            /* each carry moves one bit up */
        overflow |= (carries & topBit) >> topShift; /* a carry out of the top bit leaves the word */
        if (partial & shifted)
        {
            /* Folding the carries in would itself carry: repeat the half-add step. */
            carries = partial & shifted;
            partial ^= shifted;
        }
        else
        {
            /* No bit position collides, so OR equals the sum. */
            *carryOut = overflow;
            return partial | shifted;
        }
    }
}
/*
 * Computes a - b - borrowIn with the bitwise adder, like a CPU subtract
 * with borrow.
 *
 * borrowIn  - incoming borrow bit (0 or 1).
 * borrowOut - set to 1 if the subtraction needed a borrow (the true result
 *             is negative), else 0.
 * Returns the low bits of a - b - borrowIn.
 */
unsigned sub(unsigned a, unsigned b, unsigned borrowIn, unsigned* borrowOut)
{
    /* Two's complement: a - b - borrowIn == a + ~b + (1 - borrowIn),
       so the inverted borrow-in becomes the adder's carry-in. */
    unsigned carryIn = borrowIn ^ 1;
    unsigned carry;
    unsigned result = add(a, ~b, carryIn, &carry);

    /* The adder carries out exactly when no borrow was needed. */
    *borrowOut = carry ^ 1;
    return result;
}
/*
 * Unsigned comparison without relational operators.
 * Returns 1 if a < b, else 0: a - b needs a borrow exactly when a < b.
 */
unsigned less(unsigned a, unsigned b)
{
    unsigned isBelow;

    (void)sub(a, b, 0, &isBelow); /* only the borrow matters; discard the difference */
    return isBelow;
}
/*
 * Divides the two-word unsigned dividend *dividendh:*dividendl by divisor.
 * It uses restoring (shift-and-subtract) division built only from bitwise
 * operations and add()/sub()/less(), like the x86 DIV instruction.
 *
 * On success returns 1, with the quotient in *dividendl and the remainder
 * in *dividendh. Returns 0 and leaves both words untouched when the
 * quotient would not fit in one word (*dividendh >= divisor); this
 * includes divisor == 0.
 */
int udiv(unsigned* dividendh, unsigned* dividendl, unsigned divisor)
{
    const unsigned msb = UINT_MAX / 2 + 1; // +-*/ are OK in constants
    const unsigned msbShift = sizeof(unsigned) * CHAR_BIT - 1;
    unsigned rem, quo, step, spill, borrow;

    if (less(*dividendh, divisor) ^ 1/* *dividendh >= divisor */)
        return 0; // overflow (or division by 0)

    /* Invariant: rem < divisor at the top of every iteration. */
    rem = *dividendh;
    quo = *dividendl;

    /* One iteration per bit of a word. The counter is a one-hot mask shifted
       right, so the loop itself needs neither ++ nor a comparison. */
    for (step = msb; step; step >>= 1)
    {
        /* Bit about to be shifted out of the partial remainder. If it is set,
           the true (one-bit-wider) remainder is >= 2^w > divisor. */
        spill = rem & msb;
        rem = (rem << 1) | (quo >> msbShift);
        quo <<= 1;
        if (spill | (less(rem, divisor) ^ 1)/* remainder >= divisor */)
        {
            /* Wrap-around subtraction gives the exact remainder even when
               spill is set, because the true difference is < divisor. */
            rem = sub(rem, divisor, 0, &borrow);
            quo |= 1;
        }
    }

    *dividendh = rem;
    *dividendl = quo;
    return 1;
}
/*
 * Reference division built on the native / and % operators; main() uses it
 * to cross-check udiv().
 *
 * *dividendh:*dividendl is the two-word dividend. On success returns 1,
 * with the quotient in *dividendl and the remainder in *dividendh. If
 * *dividendh >= divisor (quotient too wide, or divisor == 0), returns 0 and
 * leaves both words untouched.
 *
 * NOTE(review): assumes unsigned long long is at least twice as wide as
 * unsigned; otherwise the shift below is undefined.
 */
int udiv2(unsigned* dividendh, unsigned* dividendl, unsigned divisor)
{
    const unsigned wordBits = sizeof(unsigned) * CHAR_BIT;
    unsigned long long wide;

    if (divisor <= *dividendh)
        return 0; // overflow

    wide = (unsigned long long)*dividendh << wordBits;
    wide |= *dividendl;
    *dividendl = (unsigned)(wide / divisor);
    *dividendh = (unsigned)(wide % divisor);
    return 1;
}
/*
 * Runs udiv() and the reference udiv2() on the two-word dividend hi:lo and
 * divisor divr, and prints both results.
 * Returns 1 if they agree on the success/overflow status and on quotient
 * and remainder. Otherwise prints " err" and returns 0.
 */
static int checkCase(unsigned hi, unsigned lo, unsigned divr)
{
    unsigned divh = hi, divl = lo;
    unsigned divh2 = hi, divl2 = lo;
    int ok, ok2;

    if (hi)
        printf("0x%08X%08X/0x%08X=", hi, lo, divr);
    else
        printf("0x%08X/0x%08X=", divl, divr);
    ok = udiv(&divh, &divl, divr);
    if (ok)
        printf("0x%08X.0x%08X", divl, divh);
    else
        printf("ovf");
    printf(" ");
    ok2 = udiv2(&divh2, &divl2, divr);
    if (ok2)
        printf("0x%08X.0x%08X", divl2, divh2);
    else
        printf("ovf");
    /* The status must be compared too. For e.g. 0/5, a bogus "ovf" from
       udiv() would leave 0:0 in place, which equals the correct quotient
       and remainder. */
    if ((ok != ok2) || (divl != divl2) || (divh != divh2))
    {
        printf(" err");
        return 0;
    }
    printf("\n");
    return 1;
}

int main(void)
{
    unsigned long long dividend, divisor;
    unsigned k, m;

    /* Exhaustive sweep over small dividends (high word 0) and small divisors,
       including divisor 0. */
    for (dividend = 0; dividend <= /*UINT_MAX*/0xFFFF; dividend++)
        for (divisor = 0; divisor <= /*UINT_MAX*/0xFF; divisor++)
            if (!checkCase(0, (unsigned)dividend, (unsigned)divisor))
                return -1;

    /* Divisors just below UINT_MAX with high words near them. Here the
       partial remainder's top bit gets shifted out, which exercises the
       branch of udiv() that the sweep above never reaches (high word 0,
       divisor <= 0xFF). High words >= divisor (m <= k) cover the overflow
       path. */
    for (k = 0; k <= 0xFF; k++)
        for (m = 0; m <= 0xFF; m++)
            if (!checkCase(UINT_MAX - m, ~((k << 8) | m), UINT_MAX - k))
                return -1;
    return 0;
}