upd: fix some bug

This commit is contained in:
2023-10-31 16:35:28 +08:00
parent fd6e5e208e
commit bcec853fe6

View File

@ -149,6 +149,12 @@ void int2048::print() {
delete[] buf; delete[] buf;
} }
/**
* @brief Claim memory for the number.
*
* @details warning: ClaimMem doesn't change num_length, so you should change it
* manually.
*/
void int2048::ClaimMem(size_t number_length) { void int2048::ClaimMem(size_t number_length) {
size_t new_number_blocks = (number_length + kNum - 1) / kNum; size_t new_number_blocks = (number_length + kNum - 1) / kNum;
if (new_number_blocks > buf_length) { if (new_number_blocks > buf_length) {
@ -178,8 +184,8 @@ inline void UnsignedMinus(int2048 &, const int2048 *, bool inverse = false);
inline void UnsignedAdd(int2048 &A, const int2048 *const pB, inline void UnsignedAdd(int2048 &A, const int2048 *const pB,
bool inverse = false) { bool inverse = false) {
if (&A == pB) throw "UnsignedAdd: A and B are the same object"; if (&A == pB) throw "UnsignedAdd: A and B are the same object";
A.ClaimMem(std::max(A.num_length, pB->num_length) + 2);
if (!inverse) { if (!inverse) {
A.ClaimMem(std::max(A.num_length, pB->num_length) + 2);
for (int i = 0; for (int i = 0;
i < (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) / i < (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) /
int2048::kNum; int2048::kNum;
@ -189,19 +195,6 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB,
if (i + 1 < A.buf_length) A.val[i + 1] += A.val[i] / int2048::kStoreBase; if (i + 1 < A.buf_length) A.val[i + 1] += A.val[i] / int2048::kStoreBase;
A.val[i] %= int2048::kStoreBase; A.val[i] %= int2048::kStoreBase;
} }
} else {
for (int i = (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) /
int2048::kNum -
1;
i >= 0; i--) {
if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum)
A.val[i] += pB->val[i];
if (A.val[i] >= int2048::kStoreBase && i - 1 >= 0) {
A.val[i - 1] += A.val[i] / int2048::kStoreBase;
A.val[i] %= int2048::kStoreBase;
}
}
}
A.num_length = std::max(A.num_length, pB->num_length); A.num_length = std::max(A.num_length, pB->num_length);
const static int kPow10[9] = {1, 10, 100, 1000, 10000, const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000}; 100000, 1000000, 10000000, 100000000};
@ -209,6 +202,24 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB,
kPow10[A.num_length % int2048::kNum] > kPow10[A.num_length % int2048::kNum] >
0) 0)
A.num_length++; A.num_length++;
} else {
assert(("this code shouldn't be executed", 0));
assert(A.num_length % int2048::kNum == 0);
assert(pB->num_length % int2048::kNum == 0);
A.ClaimMem(std::max(A.num_length, pB->num_length));
A.num_length = std::max(A.num_length, pB->num_length);
for (int i = std::max(A.num_length, pB->num_length) / int2048::kNum - 1;
i >= 0; i--) {
if (i < pB->num_length / int2048::kNum) A.val[i] += pB->val[i];
if (A.val[i] >= int2048::kStoreBase && i - 1 >= 0) {
A.val[i - 1] += A.val[i] / int2048::kStoreBase;
A.val[i] %= int2048::kStoreBase;
}
}
while (A.num_length > int2048::kNum &&
A.val[A.num_length / int2048::kNum - 1] == 0)
A.num_length -= int2048::kNum;
}
} }
// 加上一个大整数 // 加上一个大整数
@ -265,11 +276,23 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) {
A.val[i + 1]--; A.val[i + 1]--;
} }
} }
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
int new_length = 0;
for (int i = 0; i < A.num_length; i++)
if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0)
new_length = i + 1;
A.num_length = new_length;
if (A.num_length == 0) A.num_length = 1;
A.ClaimMem(A.num_length);
} else { } else {
int blocks_A = (A.num_length + int2048::kNum - 1) / int2048::kNum; assert(A.num_length % int2048::kNum == 0);
int blocks_B = (pB->num_length + int2048::kNum - 1) / int2048::kNum; assert(pB->num_length % int2048::kNum == 0);
int blocks_A = A.num_length / int2048::kNum;
int blocks_B = pB->num_length / int2048::kNum;
if (blocks_A < blocks_B) { if (blocks_A < blocks_B) {
A.ClaimMem(blocks_B * int2048::kNum); A.ClaimMem(blocks_B * int2048::kNum);
A.num_length = blocks_B * int2048::kNum;
blocks_A = blocks_B; blocks_A = blocks_B;
} }
for (int i = (pB->num_length + int2048::kNum - 1) / int2048::kNum - 1; for (int i = (pB->num_length + int2048::kNum - 1) / int2048::kNum - 1;
@ -280,16 +303,11 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) {
A.val[i - 1]--; A.val[i - 1]--;
} }
} }
} while (A.num_length > int2048::kNum &&
const static int kPow10[9] = {1, 10, 100, 1000, 10000, A.val[A.num_length / int2048::kNum - 1] == 0)
100000, 1000000, 10000000, 100000000}; A.num_length -= int2048::kNum;
int new_length = 0;
for (int i = 0; i < A.num_length; i++)
if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0)
new_length = i + 1;
A.num_length = new_length;
if (A.num_length == 0) A.num_length = 1;
A.ClaimMem(A.num_length); A.ClaimMem(A.num_length);
}
} }
// 减去一个大整数 // 减去一个大整数
@ -488,7 +506,7 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
int blocks_of_A = ((A.num_length + int2048::kNum - 1) / int2048::kNum); int blocks_of_A = ((A.num_length + int2048::kNum - 1) / int2048::kNum);
int blocks_of_B = ((pB->num_length + int2048::kNum - 1) / int2048::kNum); int blocks_of_B = ((pB->num_length + int2048::kNum - 1) / int2048::kNum);
int max_blocks = blocks_of_A + blocks_of_B; int max_blocks = blocks_of_A + blocks_of_B;
int NTT_blocks = 1; int NTT_blocks = 2;
while (NTT_blocks < (max_blocks << 1)) NTT_blocks <<= 1; while (NTT_blocks < (max_blocks << 1)) NTT_blocks <<= 1;
__int128_t *pDA = new __int128_t[NTT_blocks](); __int128_t *pDA = new __int128_t[NTT_blocks]();
__int128_t *pDB = new __int128_t[NTT_blocks](); __int128_t *pDB = new __int128_t[NTT_blocks]();
@ -503,6 +521,8 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlockBase; pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlockBase;
} }
} else { } else {
assert(A.num_length % int2048::kNum == 0);
assert(pB->num_length % int2048::kNum == 0);
pDA[0] = A.val[0]; pDA[0] = A.val[0];
for (int i = 1; i < blocks_of_A; i++) { for (int i = 1; i < blocks_of_A; i++) {
pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase; pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase;
@ -548,6 +568,7 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
} }
} }
A.num_length = NTT_blocks * 4; A.num_length = NTT_blocks * 4;
if (!inverse) {
const static int kPow10[9] = {1, 10, 100, 1000, 10000, const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000}; 100000, 1000000, 10000000, 100000000};
while (A.val[(A.num_length - 1) / int2048::kNum] / while (A.val[(A.num_length - 1) / int2048::kNum] /
@ -559,6 +580,12 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
break; break;
} }
} }
} else {
while (A.num_length > int2048::kNum &&
A.val[A.num_length / int2048::kNum - 1] == 0)
A.num_length -= int2048::kNum;
A.ClaimMem(A.num_length);
}
delete[] pDA; delete[] pDA;
delete[] pDB; delete[] pDB;
delete[] pDC; delete[] pDC;
@ -596,8 +623,11 @@ int2048 operator*(int2048 A, const int2048 &B) {
return std::move(A); return std::move(A);
} }
void int2048::ProcessHalfBlock() { void int2048::ProcessHalfBlock() {
this->ClaimMem(this->num_length + int2048::kNTTBlockBase); assert(this->num_length % int2048::kNum == 0);
int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum; this->ClaimMem(this->num_length + int2048::kNum);
this->num_length = this->num_length + int2048::kNum;
assert(this->num_length % int2048::kNum == 0);
int blocks_num = this->num_length / int2048::kNum;
for (int i = blocks_num - 1; i >= 1; i--) { for (int i = blocks_num - 1; i >= 1; i--) {
val[i] /= int2048::kNTTBlockBase; val[i] /= int2048::kNTTBlockBase;
val[i] += (val[i - 1] % int2048::kNTTBlockBase) * int2048::kNTTBlockBase; val[i] += (val[i - 1] % int2048::kNTTBlockBase) * int2048::kNTTBlockBase;
@ -605,13 +635,16 @@ void int2048::ProcessHalfBlock() {
val[0] /= int2048::kNTTBlockBase; val[0] /= int2048::kNTTBlockBase;
} }
void int2048::RestoreHalfBlock() { void int2048::RestoreHalfBlock() {
int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum; assert(this->num_length % int2048::kNum == 0);
int blocks_num = this->num_length / int2048::kNum;
for (int i = 0; i < blocks_num - 1; i++) { for (int i = 0; i < blocks_num - 1; i++) {
val[i] *= int2048::kNTTBlockBase; val[i] *= int2048::kNTTBlockBase;
val[i] %= int2048::kStoreBase; val[i] %= int2048::kStoreBase;
val[i] += val[i + 1] / int2048::kNTTBlockBase; val[i] += val[i + 1] / int2048::kNTTBlockBase;
} }
(val[blocks_num - 1] *= int2048::kNTTBlockBase) %= int2048::kStoreBase; (val[blocks_num - 1] *= int2048::kNTTBlockBase) %= int2048::kStoreBase;
while (this->num_length > 0 && val[this->num_length / int2048::kNum - 1] == 0)
this->num_length -= int2048::kNum;
} }
inline void UnsignedDivide(int2048 &A, const int2048 *pB) { inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
int L1 = A.num_length, L2 = pB->num_length; int L1 = A.num_length, L2 = pB->num_length;
@ -638,12 +671,15 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
int pow_B = (L2 + int2048::kNum - 1) / int2048::kNum - 1; int pow_B = (L2 + int2048::kNum - 1) / int2048::kNum - 1;
// pow_B+1 is the number of blocks (with number) of B' // pow_B+1 is the number of blocks (with number) of B'
int2048 inverse_B(*pB); int2048 inverse_B(*pB);
inverse_B.num_length = (inverse_B.num_length + int2048::kNum - 1) /
int2048::kNum * int2048::kNum;
for (int i = 0; (i << 1) < (pow_B + 1); i++) for (int i = 0; (i << 1) < (pow_B + 1); i++)
std::swap(inverse_B.val[i], inverse_B.val[pow_B - i]); std::swap(inverse_B.val[i], inverse_B.val[pow_B - i]);
int2048 x( int2048 x(
int2048::kStoreBase * int2048::kStoreBase *
(long long)std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1))); (long long)std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1)));
assert(x.val[1] == std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1))); assert(x.val[1] == std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1)));
x.num_length = 2 * int2048::kNum;
int *store[2]; int *store[2];
store[0] = new int[pow_A + 5](); store[0] = new int[pow_A + 5]();
store[1] = new int[pow_A + 5](); store[1] = new int[pow_A + 5]();
@ -659,21 +695,18 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
} }
while (true) { while (true) {
int2048 inverse_two(2), tmp_x(x); int2048 inverse_two(2), tmp_x(x);
inverse_two.num_length = int2048::kNum;
int tmp_x_error = 0; int tmp_x_error = 0;
if (tmp_x.val[0] >= int2048::kNTTBlockBase) { if (tmp_x.val[0] >= int2048::kNTTBlockBase) {
tmp_x_error = 1; tmp_x_error = 1;
tmp_x.ProcessHalfBlock(); tmp_x.ProcessHalfBlock();
} }
assert(tmp_x.num_length % int2048::kNum == 0);
assert(inverse_B.num_length % int2048::kNum == 0);
UnsignedMultiply(tmp_x, &inverse_B, true); UnsignedMultiply(tmp_x, &inverse_B, true);
tmp_x.num_length =
((tmp_x.num_length + int2048::kNum - 1) / int2048::kNum) *
int2048::kNum;
for (int i = 0; i < tmp_x_error + inverseB_error; i++) for (int i = 0; i < tmp_x_error + inverseB_error; i++)
tmp_x.RestoreHalfBlock(); tmp_x.RestoreHalfBlock();
UnsignedMinus(inverse_two, &tmp_x, true); UnsignedMinus(inverse_two, &tmp_x, true);
inverse_two.num_length =
((inverse_two.num_length + int2048::kNum - 1) / int2048::kNum) *
int2048::kNum;
int inverse_two_error = 0, x_error = 0; int inverse_two_error = 0, x_error = 0;
if (inverse_two.val[0] >= int2048::kNTTBlockBase) { if (inverse_two.val[0] >= int2048::kNTTBlockBase) {
inverse_two_error = 1; inverse_two_error = 1;
@ -684,8 +717,6 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
x.ProcessHalfBlock(); x.ProcessHalfBlock();
} }
UnsignedMultiply(x, &inverse_two, true); UnsignedMultiply(x, &inverse_two, true);
x.num_length =
((x.num_length + int2048::kNum - 1) / int2048::kNum) * int2048::kNum;
for (int i = 0; i < x_error + inverse_two_error; i++) x.RestoreHalfBlock(); for (int i = 0; i < x_error + inverse_two_error; i++) x.RestoreHalfBlock();
/** /**
* now x is the next x, store[tot] stores last x, store[tot^1] stores the x * now x is the next x, store[tot] stores last x, store[tot^1] stores the x
@ -694,6 +725,7 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
int blocks_of_x = (x.num_length + int2048::kNum - 1) / int2048::kNum; int blocks_of_x = (x.num_length + int2048::kNum - 1) / int2048::kNum;
if (blocks_of_x > pow_A + 3) { if (blocks_of_x > pow_A + 3) {
x.ClaimMem((pow_A + 3) * int2048::kNum); x.ClaimMem((pow_A + 3) * int2048::kNum);
x.num_length = (pow_A + 3) * int2048::kNum;
blocks_of_x = pow_A + 3; blocks_of_x = pow_A + 3;
} }
bool pre_same = true, pre_pre_same = true; bool pre_same = true, pre_pre_same = true;