From bcec853fe6004d7c430cb7e05fa31f66c199cdeb Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Tue, 31 Oct 2023 16:35:28 +0800 Subject: [PATCH] upd: fix some bug --- src/int2048.cpp | 122 ++++++++++++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 45 deletions(-) diff --git a/src/int2048.cpp b/src/int2048.cpp index 17a9bf4..51ac867 100644 --- a/src/int2048.cpp +++ b/src/int2048.cpp @@ -149,6 +149,12 @@ void int2048::print() { delete[] buf; } +/** + * @brief Claim memory for the number. + * + * @details warning: ClaimMem doesn't change num_length, so you should change it + * manually. + */ void int2048::ClaimMem(size_t number_length) { size_t new_number_blocks = (number_length + kNum - 1) / kNum; if (new_number_blocks > buf_length) { @@ -178,8 +184,8 @@ inline void UnsignedMinus(int2048 &, const int2048 *, bool inverse = false); inline void UnsignedAdd(int2048 &A, const int2048 *const pB, bool inverse = false) { if (&A == pB) throw "UnsignedAdd: A and B are the same object"; - A.ClaimMem(std::max(A.num_length, pB->num_length) + 2); if (!inverse) { + A.ClaimMem(std::max(A.num_length, pB->num_length) + 2); for (int i = 0; i < (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) / int2048::kNum; @@ -189,26 +195,31 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB, if (i + 1 < A.buf_length) A.val[i + 1] += A.val[i] / int2048::kStoreBase; A.val[i] %= int2048::kStoreBase; } + A.num_length = std::max(A.num_length, pB->num_length); + const static int kPow10[9] = {1, 10, 100, 1000, 10000, + 100000, 1000000, 10000000, 100000000}; + if (A.val[A.num_length / int2048::kNum] / + kPow10[A.num_length % int2048::kNum] > + 0) + A.num_length++; } else { - for (int i = (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) / - int2048::kNum - - 1; + assert(("this code shouldn't be executed", 0)); + assert(A.num_length % int2048::kNum == 0); + assert(pB->num_length % int2048::kNum == 0); + A.ClaimMem(std::max(A.num_length, pB->num_length)); + A.num_length = std::max(A.num_length, pB->num_length); + for (int i = std::max(A.num_length, pB->num_length) / int2048::kNum - 1; i >= 0; i--) { - if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum) - A.val[i] += pB->val[i]; + if (i < pB->num_length / int2048::kNum) A.val[i] += pB->val[i]; if (A.val[i] >= int2048::kStoreBase && i - 1 >= 0) { A.val[i - 1] += A.val[i] / int2048::kStoreBase; A.val[i] %= int2048::kStoreBase; } } + while (A.num_length > int2048::kNum && + A.val[A.num_length / int2048::kNum - 1] == 0) + A.num_length -= int2048::kNum; } - A.num_length = std::max(A.num_length, pB->num_length); - const static int kPow10[9] = {1, 10, 100, 1000, 10000, - 100000, 1000000, 10000000, 100000000}; - if (A.val[A.num_length / int2048::kNum] / - kPow10[A.num_length % int2048::kNum] > - 0) - A.num_length++; } // 加上一个大整数 @@ -265,11 +276,23 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) { A.val[i + 1]--; } } + const static int kPow10[9] = {1, 10, 100, 1000, 10000, + 100000, 1000000, 10000000, 100000000}; + int new_length = 0; + for (int i = 0; i < A.num_length; i++) + if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0) + new_length = i + 1; + A.num_length = new_length; + if (A.num_length == 0) A.num_length = 1; + A.ClaimMem(A.num_length); } else { - int blocks_A = (A.num_length + int2048::kNum - 1) / int2048::kNum; - int blocks_B = (pB->num_length + int2048::kNum - 1) / int2048::kNum; + assert(A.num_length % int2048::kNum == 0); + assert(pB->num_length % int2048::kNum == 0); + int blocks_A = A.num_length / int2048::kNum; + int blocks_B = pB->num_length / int2048::kNum; if (blocks_A < blocks_B) { A.ClaimMem(blocks_B * int2048::kNum); + A.num_length = blocks_B * int2048::kNum; blocks_A = blocks_B; } for (int i = (pB->num_length + int2048::kNum - 1) / int2048::kNum - 1; @@ -280,16 +303,11 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) { A.val[i - 1]--; } } + while (A.num_length > int2048::kNum && + A.val[A.num_length / int2048::kNum - 1] == 0) + A.num_length -= int2048::kNum; + A.ClaimMem(A.num_length); } - const static int kPow10[9] = {1, 10, 100, 1000, 10000, - 100000, 1000000, 10000000, 100000000}; - int new_length = 0; - for (int i = 0; i < A.num_length; i++) - if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0) - new_length = i + 1; - A.num_length = new_length; - if (A.num_length == 0) A.num_length = 1; - A.ClaimMem(A.num_length); } // 减去一个大整数 @@ -488,7 +506,7 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB, int blocks_of_A = ((A.num_length + int2048::kNum - 1) / int2048::kNum); int blocks_of_B = ((pB->num_length + int2048::kNum - 1) / int2048::kNum); int max_blocks = blocks_of_A + blocks_of_B; - int NTT_blocks = 1; + int NTT_blocks = 2; while (NTT_blocks < (max_blocks << 1)) NTT_blocks <<= 1; __int128_t *pDA = new __int128_t[NTT_blocks](); __int128_t *pDB = new __int128_t[NTT_blocks](); @@ -503,6 +521,8 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB, pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlockBase; } } else { + assert(A.num_length % int2048::kNum == 0); + assert(pB->num_length % int2048::kNum == 0); pDA[0] = A.val[0]; for (int i = 1; i < blocks_of_A; i++) { pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase; @@ -548,16 +568,23 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB, } } A.num_length = NTT_blocks * 4; - const static int kPow10[9] = {1, 10, 100, 1000, 10000, - 100000, 1000000, 10000000, 100000000}; - while (A.val[(A.num_length - 1) / int2048::kNum] / - kPow10[(A.num_length - 1) % int2048::kNum] == - 0) { - A.num_length--; - if (A.num_length == 0) { - A.num_length = 1; - break; + if (!inverse) { + const static int kPow10[9] = {1, 10, 100, 1000, 10000, + 100000, 1000000, 10000000, 100000000}; + while (A.val[(A.num_length - 1) / int2048::kNum] / + kPow10[(A.num_length - 1) % int2048::kNum] == + 0) { + A.num_length--; + if (A.num_length == 0) { + A.num_length = 1; + break; + } } + } else { + while (A.num_length > int2048::kNum && + A.val[A.num_length / int2048::kNum - 1] == 0) + A.num_length -= int2048::kNum; + A.ClaimMem(A.num_length); } delete[] pDA; delete[] pDB; @@ -596,8 +623,11 @@ int2048 operator*(int2048 A, const int2048 &B) { return std::move(A); } void int2048::ProcessHalfBlock() { - this->ClaimMem(this->num_length + int2048::kNTTBlockBase); - int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum; + assert(this->num_length % int2048::kNum == 0); + this->ClaimMem(this->num_length + int2048::kNum); + this->num_length = this->num_length + int2048::kNum; + assert(this->num_length % int2048::kNum == 0); + int blocks_num = this->num_length / int2048::kNum; for (int i = blocks_num - 1; i >= 1; i--) { val[i] /= int2048::kNTTBlockBase; val[i] += (val[i - 1] % int2048::kNTTBlockBase) * int2048::kNTTBlockBase; @@ -605,13 +635,16 @@ void int2048::ProcessHalfBlock() { val[0] /= int2048::kNTTBlockBase; } void int2048::RestoreHalfBlock() { - int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum; + assert(this->num_length % int2048::kNum == 0); + int blocks_num = this->num_length / int2048::kNum; for (int i = 0; i < blocks_num - 1; i++) { val[i] *= int2048::kNTTBlockBase; val[i] %= int2048::kStoreBase; val[i] += val[i + 1] / int2048::kNTTBlockBase; } (val[blocks_num - 1] *= int2048::kNTTBlockBase) %= int2048::kStoreBase; + while (this->num_length > 0 && val[this->num_length / int2048::kNum - 1] == 0) + this->num_length -= int2048::kNum; } inline void UnsignedDivide(int2048 &A, const int2048 *pB) { int L1 = A.num_length, L2 = pB->num_length; @@ -638,12 +671,15 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { int pow_B = (L2 + int2048::kNum - 1) / int2048::kNum - 1; // pow_B+1 is the number of blocks (with number) of B' int2048 inverse_B(*pB); + inverse_B.num_length = (inverse_B.num_length + int2048::kNum - 1) / + int2048::kNum * int2048::kNum; for (int i = 0; (i << 1) < (pow_B + 1); i++) std::swap(inverse_B.val[i], inverse_B.val[pow_B - i]); int2048 x( int2048::kStoreBase * (long long)std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1))); assert(x.val[1] == std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1))); + x.num_length = 2 * int2048::kNum; int *store[2]; store[0] = new int[pow_A + 5](); store[1] = new int[pow_A + 5](); @@ -659,21 +695,18 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { } while (true) { int2048 inverse_two(2), tmp_x(x); + inverse_two.num_length = int2048::kNum; int tmp_x_error = 0; if (tmp_x.val[0] >= int2048::kNTTBlockBase) { tmp_x_error = 1; tmp_x.ProcessHalfBlock(); } + assert(tmp_x.num_length % int2048::kNum == 0); + assert(inverse_B.num_length % int2048::kNum == 0); UnsignedMultiply(tmp_x, &inverse_B, true); - tmp_x.num_length = - ((tmp_x.num_length + int2048::kNum - 1) / int2048::kNum) * - int2048::kNum; for (int i = 0; i < tmp_x_error + inverseB_error; i++) tmp_x.RestoreHalfBlock(); UnsignedMinus(inverse_two, &tmp_x, true); - inverse_two.num_length = - ((inverse_two.num_length + int2048::kNum - 1) / int2048::kNum) * - int2048::kNum; int inverse_two_error = 0, x_error = 0; if (inverse_two.val[0] >= int2048::kNTTBlockBase) { inverse_two_error = 1; @@ -684,8 +717,6 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { x.ProcessHalfBlock(); } UnsignedMultiply(x, &inverse_two, true); - x.num_length = - ((x.num_length + int2048::kNum - 1) / int2048::kNum) * int2048::kNum; for (int i = 0; i < x_error + inverse_two_error; i++) x.RestoreHalfBlock(); /** * now x is the next x, store[tot] stores last x, store[tot^1] stores the x @@ -694,6 +725,7 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { int blocks_of_x = (x.num_length + int2048::kNum - 1) / int2048::kNum; if (blocks_of_x > pow_A + 3) { x.ClaimMem((pow_A + 3) * int2048::kNum); + x.num_length = (pow_A + 3) * int2048::kNum; blocks_of_x = pow_A + 3; } bool pre_same = true, pre_pre_same = true;