I am attempting to implement the Schönhage-Strassen multiplication algorithm using the NTT, and am running into a problem where the final resulting vector does not equal the expected product.
For two input vectors a and b, each consisting of N "digits" of K bits (with the final N/2 entries of each set to 0), a modulus M = 2^(2*K)+1, a root of unity w = N^(4*K-1) | w^N = 1 mod M, a modular inverse of this value wi | wi*w = 1 mod M, and u | u*N = 1 mod M, the following Python code is used to (attempt to) multiply these vectors using the Schönhage-Strassen algorithm:
#a and b are lists of length N, representing large integers
A = [ sum([ (a[i]*pow(w,i*j,M))%M for i in range(N)]) for j in range(N)] #NTT of a
B = [ sum([ (b[i]*pow(w,i*j,M))%M for i in range(N)]) for j in range(N)] #NTT of b
C = [ (A[i]*B[i])%M for i in range(N)] #A * B multiplied pointwise
c = [ sum([ (C[i]*pow(wi,i*j,M))%M for i in range(N)]) for j in range(N)] #intermediate step in INTT of C
ci = [ (i*u)%M for i in c] #INTT of C, should be product of a and b
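For anyone who wants to run this directly, here is a self-contained sketch of the same steps with the parameters from my N = K = 8 test filled in. The way w is computed here (2^(4*K/N) mod M) is an assumption on my part that reproduces the w = 16 from that test; wi and u come from Python's three-argument pow with a negative exponent (Python 3.8+), and `transform` is just a helper name for this sketch.

```python
# Parameters from the N = K = 8 test case.
N, K = 8, 8
M = 2**(2*K) + 1           # 65537
w = pow(2, 4*K // N, M)    # assumption: 2 has order 4*K mod M, so this has order N
wi = pow(w, -1, M)         # modular inverse of w (Python 3.8+)
u = pow(N, -1, M)          # modular inverse of N

a = [212, 251, 84, 186, 0, 0, 0, 0]
b = [180, 27, 234, 225, 0, 0, 0, 0]

def transform(v, root):
    # Naive O(N^2) transform using the same formula as the snippet above.
    # As in that snippet, each term is reduced mod M but the sum is not.
    return [sum((v[i] * pow(root, i*j, M)) % M for i in range(N))
            for j in range(N)]

A = transform(a, w)                         # NTT of a
B = transform(b, w)                         # NTT of b
C = [(A[i] * B[i]) % M for i in range(N)]   # pointwise product
c = transform(C, wi)                        # unscaled inverse transform
ci = [(x * u) % M for x in c]               # scale by N^-1 mod M

print(w, wi, u)   # 16 61441 57345
print(ci)         # [38160, 50904, 5968, 11108, 15616, 62424, 41850, 0]
```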
In theory, taking the NTT of a and b, multiplying pointwise, then taking the INTT of the result should give the product, if I am not mistaken, and I have tested these methods for NTT and INTT to confirm that they are inverses of each other. However, the final resulting vector ci, rather than being equal to the product of a and b, is the product where each element is taken modulo M, giving an incorrect result for the product.
For example, running a test with N=K=8 and random vectors for a and b gives the following:
M = 2^(2*8)+1 = 65537
w = 16, wi = 61441
u = 57345
a = [212, 251, 84, 186, 0, 0, 0, 0] (3126131668 as an integer)
b = [180, 27, 234, 225, 0, 0, 0, 0] (3790216116)
NTT(a) = [733, 66681, 147842, 92262, 130933, 107825, 114562, 127302]
NTT(b) = [666, 64598, 80332, 54468, 131236, 186644, 181708, 88232]
Pointwise product of above two lines mod M = [29419, 39913, 25015, 14993, 42695, 49488, 52438, 51319]
INTT of above line (i.e. result) = [38160, 50904, 5968, 11108, 15616, 62424, 41850, 0] (11848430946168040720)
Actual product of a x b = [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0] (11848714628791561488)
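The "actual product" row and the integers in parentheses can be reproduced without any NTT, using a schoolbook convolution. This is a minimal sketch: `schoolbook` and `to_int` are helper names made up for this check, and digits are assumed to be stored least-significant first in base 2^K.

```python
K = 8
a = [212, 251, 84, 186, 0, 0, 0, 0]
b = [180, 27, 234, 225, 0, 0, 0, 0]

def schoolbook(x, y):
    # Plain acyclic convolution with no modular reduction, truncated to len(x)
    # entries. Nothing is lost here because the upper halves of x and y are zero.
    n = len(x)
    out = [0] * n
    for i in range(n):
        for j in range(n - i):
            out[i + j] += x[i] * y[j]
    return out

def to_int(digits, bits=K):
    # Evaluate the digit vector at 2**bits; entries above 2**bits act as carries.
    return sum(d << (bits * k) for k, d in enumerate(digits))

expected = schoolbook(a, b)
print(expected)                                    # [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0]
print(to_int(a), to_int(b))                        # 3126131668 3790216116
print(to_int(expected))                            # 11848714628791561488
print(to_int(a) * to_int(b) == to_int(expected))   # True
```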
In this example, and nearly every time I try it, the actual product and the result of my algorithm agree for several elements near the beginning and end of the vector but deviate towards the middle. As I mentioned above, the elements of ci are each equal to the corresponding elements of a*b modulo M. I must be misunderstanding something about this algorithm, though I'm not entirely sure what. Am I using the wrong modulus somewhere?
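To make the "equal modulo M" observation concrete, the sketch below reduces the true coefficients mod M and compares them with my output. It also compares M against the largest value a single product coefficient could take, which I estimate as (N/2)*(2^K - 1)^2. That bound is my own back-of-the-envelope figure (at most N/2 nonzero terms per coefficient, each at most (2^K - 1)^2), not something taken from a reference, so treat it as an assumption.

```python
N, K = 8, 8
M = 2**(2*K) + 1

actual = [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0]  # true product
result = [38160, 50904, 5968, 11108, 15616, 62424, 41850, 0]    # my INTT output

# Every entry of my result is the true coefficient reduced mod M.
print([x % M for x in actual] == result)   # True

# Assumed worst case: N/2 terms per coefficient, each at most (2**K - 1)**2.
max_coeff = (N // 2) * (2**K - 1)**2
print(max_coeff, M, max_coeff < M)         # 260100 65537 False
```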