diff --git a/include/int2048.h b/include/int2048.h index c83c521..aa30e6d 100644 --- a/include/int2048.h +++ b/include/int2048.h @@ -15,14 +15,39 @@ class int2048 { * num_length is the length of the integer, (num_length+kNum-1)/kNum is the * length of val with data. Note that position in val without data is 0. */ - size_t buf_length = 0; const static int kMod = 100000000, kNum = 8, kDefaultLength = 10; const static int kMemAdditionScalar = 2, kMemDeleteScalar = 4; + /** + * the follow data used by NTT is generated by this code: +#!/usr/bin/python3 +from sympy import isprime,primitive_root +found=False +for i in range(0,20): + for j in range(2**i,(2**(i+1))): + V=j*(2**(57-i))+1 + if isprime(V): + found=True + print(j,57-i) + print("number=",V) + print("root=",primitive_root(V)) + exit(0) + * it out puts: +95 55 +number= 180143985094819841 +root= 6 + */ + const static __int128_t kNTTMod = 180143985094819841ll; + const static __int128_t kNTTRoot = 6; + const static int kNTTBlockNum = 4; + const static int kNTTBlcokBase = 10000; + + size_t buf_length = 0; int *val = nullptr; signed char flag = +1; int num_length = 0; - void NTT(__int128_t *, int, int); + __int128_t QuickPow(__int128_t v, long long q); + void NTTTransform(__int128_t *, int, bool); public: int2048(); diff --git a/src/int2048.cpp b/src/int2048.cpp index 252a436..49ec719 100644 --- a/src/int2048.cpp +++ b/src/int2048.cpp @@ -26,6 +26,9 @@ #include #include #include + +static_assert(sizeof(int) == 4, "sizeof(int) != 4"); +static_assert(sizeof(long long) == 8, "sizeof(long long)!=8"); namespace sjtu { // 构造函数 int2048::int2048() { @@ -155,7 +158,7 @@ void int2048::ClaimMem(size_t number_length) { inline int UnsignedCmp(const int2048 &A, const int2048 &B) { if (A.num_length != B.num_length) return A.num_length < B.num_length ? -1 : 1; - int number_of_blocks = (A.num_length + A.kNum - 1) / A.kNum; + int number_of_blocks = (A.num_length + int2048::kNum - 1) / int2048::kNum; for (int i = number_of_blocks - 1; i >= 0; i--) if (A.val[i] != B.val[i]) return A.val[i] < B.val[i] ? -1 : 1; return 0; @@ -165,16 +168,20 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB) { if (&A == pB) throw "UnsignedAdd: A and B are the same object"; A.ClaimMem(std::max(A.num_length, pB->num_length) + 2); for (int i = 0; - i < (std::max(A.num_length, pB->num_length) + A.kNum - 1) / A.kNum; + i < (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) / + int2048::kNum; i++) { - if (i < (pB->num_length + pB->kNum - 1) / pB->kNum) A.val[i] += pB->val[i]; - A.val[i + 1] += A.val[i] / A.kMod; - A.val[i] %= A.kMod; + if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum) + A.val[i] += pB->val[i]; + A.val[i + 1] += A.val[i] / int2048::kMod; + A.val[i] %= int2048::kMod; } A.num_length = std::max(A.num_length, pB->num_length); const static int kPow10[9] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000}; - if (A.val[A.num_length / A.kNum] / kPow10[A.num_length % A.kNum] > 0) + if (A.val[A.num_length / int2048::kNum] / + kPow10[A.num_length % int2048::kNum] > + 0) A.num_length++; } @@ -222,10 +229,11 @@ int2048 add(int2048 A, const int2048 &B) { inline void UnsignedMinus(int2048 &A, const int2048 *const pB) { if (&A == pB) throw "UnsignedMinus: A and B are the same object"; - for (int i = 0; i < (pB->num_length + A.kNum - 1) / A.kNum; i++) { + for (int i = 0; i < (pB->num_length + int2048::kNum - 1) / int2048::kNum; + i++) { A.val[i] -= pB->val[i]; if (A.val[i] < 0) { - A.val[i] += A.kMod; + A.val[i] += int2048::kMod; A.val[i + 1]--; } } @@ -233,7 +241,8 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB) { 100000, 1000000, 10000000, 100000000}; int new_length = 0; for (int i = 0; i < A.num_length; i++) - if (A.val[i / A.kNum] / kPow10[i % A.kNum] > 0) new_length = i + 1; + if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0) + new_length = i + 1; A.num_length = new_length; if (A.num_length == 0) A.num_length = 1; A.ClaimMem(A.num_length); @@ -329,10 +338,92 @@ int2048 operator-(int2048 A, const int2048 &B) { A.minus(B); return std::move(A); } - -inline void UnsignedMultiply(int2048 &A, const int2048 *pB) -{ - ; +__int128_t int2048::QuickPow(__int128_t v, long long q) { + __int128_t ret = 1; + v %= int2048::kNTTMod; + while (q > 0) { + if (q & 1) (ret *= v) %= int2048::kNTTMod; + (v *= v) %= int2048::kNTTMod; + q >>= 1; + } + return ret; +} +void int2048::NTTTransform(__int128_t *a, int NTT_blocks, + bool inverse = false) { + for (int i = 1, j = 0; i < NTT_blocks; i++) { + int bit = NTT_blocks >> 1; + while (j >= bit) { + j -= bit; + bit >>= 1; + } + j += bit; + if (i < j) std::swap(a[i], a[j]); + } + for (int len = 2; len <= NTT_blocks; len <<= 1) { + __int128_t wlen = QuickPow(int2048::kNTTRoot, (int2048::kNTTMod - 1) / len); + if (inverse) wlen = QuickPow(wlen, int2048::kNTTMod - 2); + for (int i = 0; i < NTT_blocks; i += len) { + __int128_t w = 1; + for (int j = 0; j < len / 2; j++) { + __int128_t u = a[i + j], v = a[i + j + len / 2] * w % int2048::kNTTMod; + a[i + j] = (u + v) % int2048::kNTTMod; + a[i + j + len / 2] = (u - v + int2048::kNTTMod) % int2048::kNTTMod; + (w *= wlen) %= int2048::kNTTMod; + } + } + } + if (inverse) { + __int128_t inv = QuickPow(NTT_blocks, int2048::kNTTMod - 2); + for (int i = 0; i < NTT_blocks; i++) (a[i] *= inv) %= int2048::kNTTMod; + } +} +inline void UnsignedMultiply(int2048 &A, const int2048 *pB) { + if (&A == pB) throw "UnsignedMultiply: A and B are the same object"; + int blocks_of_A = ((A.num_length + int2048::kNum - 1) / int2048::kNum); + int blocks_of_B = ((pB->num_length + int2048::kNum - 1) / int2048::kNum); + int max_blocks = blocks_of_A + blocks_of_B; + int NTT_blocks = 1; + while (NTT_blocks < (max_blocks << 1)) NTT_blocks <<= 1; + __int128_t *pDA = new __int128_t[NTT_blocks](); + __int128_t *pDB = new __int128_t[NTT_blocks](); + __int128_t *pDC = new __int128_t[NTT_blocks](); + for (int i = 0; i < blocks_of_A; i++) { + pDA[i << 1] = A.val[i] % int2048::kNTTBlcokBase; + pDA[(i << 1) | 1] = A.val[i] / int2048::kNTTBlcokBase; + } + for (int i = 0; i < blocks_of_B; i++) { + pDB[i << 1] = pB->val[i] % int2048::kNTTBlcokBase; + pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlcokBase; + } + A.NTTTransform(pDA, NTT_blocks); + A.NTTTransform(pDB, NTT_blocks); + for (int i = 0; i < NTT_blocks; i++) + pDC[i] = (pDA[i] * pDB[i]) % int2048::kNTTMod; + A.NTTTransform(pDC, NTT_blocks, true); + for (int i = 0; i < NTT_blocks - 1; i++) { + pDC[i + 1] += pDC[i] / int2048::kNTTBlcokBase; + pDC[i] %= int2048::kNTTBlcokBase; + } + if (pDC[NTT_blocks - 1] >= int2048::kNTTBlcokBase) + throw "UnsignedMultiply: NTT result overflow"; + int flag_store = A.flag; + A.ClaimMem(NTT_blocks * 4); + memset(A.val, 0, A.buf_length * sizeof(int)); + for (int i = 0; i < NTT_blocks / 2; i++) { + A.val[i] = pDC[(i << 1) | 1] * int2048::kNTTBlcokBase + pDC[i << 1]; + } + A.num_length = NTT_blocks * 4; + const static int kPow10[9] = {1, 10, 100, 1000, 10000, + 100000, 1000000, 10000000, 100000000}; + while (A.val[(A.num_length - 1) / int2048::kNum] / + kPow10[(A.num_length - 1) % int2048::kNum] == + 0) { + A.num_length--; + if (A.num_length == 0) throw "UnsignedMultiply: num_length==0"; + } + delete[] pDA; + delete[] pDB; + delete[] pDC; } int2048 &int2048::Multiply(const int2048 &B) { @@ -345,7 +436,7 @@ int2048 &int2048::Multiply(const int2048 &B) { return *this; } this->flag = this->flag * pB->flag; - UnsignedMultiply(*this,pB); + UnsignedMultiply(*this, pB); return *this; } @@ -395,7 +486,8 @@ std::ostream &operator<<(std::ostream &stream, const int2048 &v) { const static int kPow10[9] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000}; for (int i = v.num_length - 1; i >= 0; i--) - stream << char('0' + v.val[i / v.kNum] / kPow10[i % v.kNum] % 10); + stream << char('0' + + v.val[i / int2048::kNum] / kPow10[i % int2048::kNum] % 10); return stream; }