diff --git a/include/int2048.h b/include/int2048.h index 402e249..ce94023 100644 --- a/include/int2048.h +++ b/include/int2048.h @@ -15,7 +15,7 @@ class int2048 { * num_length is the length of the integer, (num_length+kNum-1)/kNum is the * length of val with data. Note that position in val without data is 0. */ - const static int kMod = 100000000, kNum = 8, kDefaultLength = 10; + const static int kStoreBase = 100000000, kNum = 8, kDefaultLength = 10; const static int kMemAdditionScalar = 2, kMemDeleteScalar = 4; /** * the follow data used by NTT is generated by this code: @@ -50,7 +50,8 @@ root= 6 void NTTTransform(__int128_t *, int, bool); void RightMoveBy(int); - + void ProcessHalfBlock(); + void RestoreHalfBlock(); public: int2048(); int2048(long long); diff --git a/src/int2048.cpp b/src/int2048.cpp index f2b1a42..17a9bf4 100644 --- a/src/int2048.cpp +++ b/src/int2048.cpp @@ -186,8 +186,8 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB, i++) { if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum) A.val[i] += pB->val[i]; - if (i + 1 < A.buf_length) A.val[i + 1] += A.val[i] / int2048::kMod; - A.val[i] %= int2048::kMod; + if (i + 1 < A.buf_length) A.val[i + 1] += A.val[i] / int2048::kStoreBase; + A.val[i] %= int2048::kStoreBase; } } else { for (int i = (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) / @@ -196,9 +196,9 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB, i >= 0; i--) { if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum) A.val[i] += pB->val[i]; - if (A.val[i] >= int2048::kMod && i - 1 >= 0) { - A.val[i - 1] += A.val[i] / int2048::kMod; - A.val[i] %= int2048::kMod; + if (A.val[i] >= int2048::kStoreBase && i - 1 >= 0) { + A.val[i - 1] += A.val[i] / int2048::kStoreBase; + A.val[i] %= int2048::kStoreBase; } } } @@ -261,20 +261,22 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) { i++) { A.val[i] -= pB->val[i]; if (A.val[i] < 0 && i + 1 < A.buf_length) { - A.val[i] += int2048::kMod; + A.val[i] += int2048::kStoreBase; A.val[i + 1]--; } } } else { int blocks_A = (A.num_length + int2048::kNum - 1) / int2048::kNum; int blocks_B = (pB->num_length + int2048::kNum - 1) / int2048::kNum; - if (blocks_A < blocks_B) A.ClaimMem(blocks_A * int2048::kNum); - blocks_A = (A.num_length + int2048::kNum - 1) / int2048::kNum; + if (blocks_A < blocks_B) { + A.ClaimMem(blocks_B * int2048::kNum); + blocks_A = blocks_B; + } for (int i = (pB->num_length + int2048::kNum - 1) / int2048::kNum - 1; i >= 0; i--) { if (i < blocks_B && i < blocks_A) A.val[i] -= pB->val[i]; if (i < blocks_A && A.val[i] < 0 && i - 1 >= 0) { - A.val[i] += int2048::kMod; + A.val[i] += int2048::kStoreBase; A.val[i - 1]--; } } @@ -394,6 +396,63 @@ __int128_t int2048::QuickPow(__int128_t v, long long q) { } return ret; } + +// /** +// * @brief Move the number to the left by L digits. That is, v'=v*(10^L) +// */ +// void int2048::LeftMoveBy(int L) { +// const static int kPow10[9] = {1, 10, 100, 1000, 10000, +// 100000, 1000000, 10000000, 100000000}; +// int big_move = L / int2048::kNum; +// int small_move = L % int2048::kNum; +// this->ClaimMem(this->num_length + L); +// for (int i = this->buf_length - 1; i >= big_move; i--) { +// this->val[i] = this->val[i - big_move]; +// } +// for (int i = 0; i < big_move; i++) { +// this->val[i] = 0; +// } +// this->num_length += big_move * int2048::kNum; +// if (small_move == 0) return; +// for (int i = this->buf_length - 1; i >= 0; i--) { +// (this->val[i] *= kPow10[small_move]) %= int2048::kStoreBase; +// if (i - 1 >= 0) { +// this->val[i] += this->val[i - 1] / kPow10[int2048::kNum - small_move]; +// } +// } +// } + +/** + * @brief Move the number to the right by L digits. That is, v'=v//(10^L) + */ +void int2048::RightMoveBy(int L) { + if (L >= this->num_length) { + this->num_length = 1; + this->val[0] = 0; + return; + } + int big_move = L / int2048::kNum; + int small_move = L % int2048::kNum; + for (int i = 0; i < this->buf_length - big_move; i++) { + this->val[i] = this->val[i + big_move]; + } + for (int i = this->buf_length - big_move; i < this->buf_length; i++) { + this->val[i] = 0; + } + this->num_length -= big_move * int2048::kNum; + if (small_move == 0) return; + const static int kPow10[9] = {1, 10, 100, 1000, 10000, + 100000, 1000000, 10000000, 100000000}; + for (int i = 0; i < this->buf_length; i++) { + this->val[i] /= kPow10[small_move]; + if (i + 1 < this->buf_length) { + this->val[i] += this->val[i + 1] % kPow10[small_move] * + kPow10[int2048::kNum - small_move]; + } + } + this->num_length -= small_move; +} + void int2048::NTTTransform(__int128_t *a, int NTT_blocks, bool inverse = false) { for (int i = 1, j = 0; i < NTT_blocks; i++) { @@ -434,13 +493,26 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB, __int128_t *pDA = new __int128_t[NTT_blocks](); __int128_t *pDB = new __int128_t[NTT_blocks](); __int128_t *pDC = new __int128_t[NTT_blocks](); - for (int i = 0; i < blocks_of_A; i++) { - pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase; - pDA[(i << 1) | 1] = A.val[i] / int2048::kNTTBlockBase; - } - for (int i = 0; i < blocks_of_B; i++) { - pDB[i << 1] = pB->val[i] % int2048::kNTTBlockBase; - pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlockBase; + if (!inverse) { + for (int i = 0; i < blocks_of_A; i++) { + pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase; + pDA[(i << 1) | 1] = A.val[i] / int2048::kNTTBlockBase; + } + for (int i = 0; i < blocks_of_B; i++) { + pDB[i << 1] = pB->val[i] % int2048::kNTTBlockBase; + pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlockBase; + } + } else { + pDA[0] = A.val[0]; + for (int i = 1; i < blocks_of_A; i++) { + pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase; + pDA[(i << 1) - 1] = A.val[i] / int2048::kNTTBlockBase; + } + pDB[0] = pB->val[0]; + for (int i = 1; i < blocks_of_B; i++) { + pDB[i << 1] = pB->val[i] % int2048::kNTTBlockBase; + pDB[(i << 1) - 1] = pB->val[i] / int2048::kNTTBlockBase; + } } A.NTTTransform(pDA, NTT_blocks); A.NTTTransform(pDB, NTT_blocks); @@ -465,8 +537,15 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB, int flag_store = A.flag; A.ClaimMem(NTT_blocks * 4); memset(A.val, 0, A.buf_length * sizeof(int)); - for (int i = 0; i < NTT_blocks / 2; i++) { - A.val[i] = pDC[(i << 1) | 1] * int2048::kNTTBlockBase + pDC[i << 1]; + if (!inverse) { + for (int i = 0; i < NTT_blocks / 2; i++) { + A.val[i] = pDC[(i << 1) | 1] * int2048::kNTTBlockBase + pDC[i << 1]; + } + } else { + A.val[0] = pDC[0]; + for (int i = 1; i < NTT_blocks / 2; i++) { + A.val[i] = pDC[(i << 1) - 1] * int2048::kNTTBlockBase + pDC[i << 1]; + } } A.num_length = NTT_blocks * 4; const static int kPow10[9] = {1, 10, 100, 1000, 10000, @@ -516,35 +595,24 @@ int2048 operator*(int2048 A, const int2048 &B) { A.Multiply(B); return std::move(A); } - -void int2048::RightMoveBy(int L) { - if (L >= this->num_length) { - this->num_length = 1; - this->val[0] = 0; - return; +void int2048::ProcessHalfBlock() { + this->ClaimMem(this->num_length + int2048::kNTTBlockBase); + int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum; + for (int i = blocks_num - 1; i >= 1; i--) { + val[i] /= int2048::kNTTBlockBase; + val[i] += (val[i - 1] % int2048::kNTTBlockBase) * int2048::kNTTBlockBase; } - int big_move = L / int2048::kNum; - int small_move = L % int2048::kNum; - for (int i = 0; i < this->buf_length - big_move; i++) { - this->val[i] = this->val[i + big_move]; - } - for (int i = this->buf_length - big_move; i < this->buf_length; i++) { - this->val[i] = 0; - } - this->num_length -= big_move * int2048::kNum; - if (small_move == 0) return; - const static int kPow10[9] = {1, 10, 100, 1000, 10000, - 100000, 1000000, 10000000, 100000000}; - for (int i = 0; i < this->buf_length; i++) { - this->val[i] /= kPow10[small_move]; - if (i + 1 < this->buf_length) { - this->val[i] += this->val[i + 1] % kPow10[small_move] * - kPow10[int2048::kNum - small_move]; - } - } - this->num_length -= small_move; + val[0] /= int2048::kNTTBlockBase; +} +void int2048::RestoreHalfBlock() { + int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum; + for (int i = 0; i < blocks_num - 1; i++) { + val[i] *= int2048::kNTTBlockBase; + val[i] %= int2048::kStoreBase; + val[i] += val[i + 1] / int2048::kNTTBlockBase; + } + (val[blocks_num - 1] *= int2048::kNTTBlockBase) %= int2048::kStoreBase; } - inline void UnsignedDivide(int2048 &A, const int2048 *pB) { int L1 = A.num_length, L2 = pB->num_length; if (&A == pB) throw "UnsignedDivide: A and B are the same object"; @@ -572,8 +640,10 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { int2048 inverse_B(*pB); for (int i = 0; (i << 1) < (pow_B + 1); i++) std::swap(inverse_B.val[i], inverse_B.val[pow_B - i]); - int2048 x(int2048::kMod); - assert(x.val[1] == 1); + int2048 x( + int2048::kStoreBase * + (long long)std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1))); + assert(x.val[1] == std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1))); int *store[2]; store[0] = new int[pow_A + 5](); store[1] = new int[pow_A + 5](); @@ -582,11 +652,41 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { store[0][i] = A.val[i]; store[1][i] = -1; } + int inverseB_error = 0; + if (inverse_B.val[0] >= int2048::kNTTBlockBase) { + inverseB_error = 1; + inverse_B.ProcessHalfBlock(); + } while (true) { - int2048 invsere_two(2), tmp_x(x); + int2048 inverse_two(2), tmp_x(x); + int tmp_x_error = 0; + if (tmp_x.val[0] >= int2048::kNTTBlockBase) { + tmp_x_error = 1; + tmp_x.ProcessHalfBlock(); + } UnsignedMultiply(tmp_x, &inverse_B, true); - UnsignedMinus(invsere_two, &tmp_x, true); - UnsignedMultiply(x, &invsere_two, true); + tmp_x.num_length = + ((tmp_x.num_length + int2048::kNum - 1) / int2048::kNum) * + int2048::kNum; + for (int i = 0; i < tmp_x_error + inverseB_error; i++) + tmp_x.RestoreHalfBlock(); + UnsignedMinus(inverse_two, &tmp_x, true); + inverse_two.num_length = + ((inverse_two.num_length + int2048::kNum - 1) / int2048::kNum) * + int2048::kNum; + int inverse_two_error = 0, x_error = 0; + if (inverse_two.val[0] >= int2048::kNTTBlockBase) { + inverse_two_error = 1; + inverse_two.ProcessHalfBlock(); + } + if (x.val[0] >= int2048::kNTTBlockBase) { + x_error = 1; + x.ProcessHalfBlock(); + } + UnsignedMultiply(x, &inverse_two, true); + x.num_length = + ((x.num_length + int2048::kNum - 1) / int2048::kNum) * int2048::kNum; + for (int i = 0; i < x_error + inverse_two_error; i++) x.RestoreHalfBlock(); /** * now x is the next x, store[tot] stores last x, store[tot^1] stores the x * previous to store[x] @@ -617,6 +717,9 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { else store[tot][i] = 0; } + fprintf(stderr, "x: "); + for (int i = 0; i < blocks_of_x; i++) fprintf(stderr, "%08d ", x.val[i]); + fprintf(stderr, "\n"); } delete[] store[0]; delete[] store[1];