upd: fix some bug

This commit is contained in:
2023-10-31 16:35:28 +08:00
parent fd6e5e208e
commit bcec853fe6

View File

@ -149,6 +149,12 @@ void int2048::print() {
delete[] buf;
}
/**
* @brief Claim memory for the number.
*
* @details warning: ClaimMem doesn't change num_length, so you should change it
* manually.
*/
void int2048::ClaimMem(size_t number_length) {
size_t new_number_blocks = (number_length + kNum - 1) / kNum;
if (new_number_blocks > buf_length) {
@ -178,8 +184,8 @@ inline void UnsignedMinus(int2048 &, const int2048 *, bool inverse = false);
inline void UnsignedAdd(int2048 &A, const int2048 *const pB,
bool inverse = false) {
if (&A == pB) throw "UnsignedAdd: A and B are the same object";
A.ClaimMem(std::max(A.num_length, pB->num_length) + 2);
if (!inverse) {
A.ClaimMem(std::max(A.num_length, pB->num_length) + 2);
for (int i = 0;
i < (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) /
int2048::kNum;
@ -189,26 +195,31 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB,
if (i + 1 < A.buf_length) A.val[i + 1] += A.val[i] / int2048::kStoreBase;
A.val[i] %= int2048::kStoreBase;
}
A.num_length = std::max(A.num_length, pB->num_length);
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
if (A.val[A.num_length / int2048::kNum] /
kPow10[A.num_length % int2048::kNum] >
0)
A.num_length++;
} else {
for (int i = (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) /
int2048::kNum -
1;
assert(("this code shouldn't be executed", 0));
assert(A.num_length % int2048::kNum == 0);
assert(pB->num_length % int2048::kNum == 0);
A.ClaimMem(std::max(A.num_length, pB->num_length));
A.num_length = std::max(A.num_length, pB->num_length);
for (int i = std::max(A.num_length, pB->num_length) / int2048::kNum - 1;
i >= 0; i--) {
if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum)
A.val[i] += pB->val[i];
if (i < pB->num_length / int2048::kNum) A.val[i] += pB->val[i];
if (A.val[i] >= int2048::kStoreBase && i - 1 >= 0) {
A.val[i - 1] += A.val[i] / int2048::kStoreBase;
A.val[i] %= int2048::kStoreBase;
}
}
while (A.num_length > int2048::kNum &&
A.val[A.num_length / int2048::kNum - 1] == 0)
A.num_length -= int2048::kNum;
}
A.num_length = std::max(A.num_length, pB->num_length);
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
if (A.val[A.num_length / int2048::kNum] /
kPow10[A.num_length % int2048::kNum] >
0)
A.num_length++;
}
// 加上一个大整数
@ -265,11 +276,23 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) {
A.val[i + 1]--;
}
}
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
int new_length = 0;
for (int i = 0; i < A.num_length; i++)
if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0)
new_length = i + 1;
A.num_length = new_length;
if (A.num_length == 0) A.num_length = 1;
A.ClaimMem(A.num_length);
} else {
int blocks_A = (A.num_length + int2048::kNum - 1) / int2048::kNum;
int blocks_B = (pB->num_length + int2048::kNum - 1) / int2048::kNum;
assert(A.num_length % int2048::kNum == 0);
assert(pB->num_length % int2048::kNum == 0);
int blocks_A = A.num_length / int2048::kNum;
int blocks_B = pB->num_length / int2048::kNum;
if (blocks_A < blocks_B) {
A.ClaimMem(blocks_B * int2048::kNum);
A.num_length = blocks_B * int2048::kNum;
blocks_A = blocks_B;
}
for (int i = (pB->num_length + int2048::kNum - 1) / int2048::kNum - 1;
@ -280,16 +303,11 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB, bool inverse) {
A.val[i - 1]--;
}
}
while (A.num_length > int2048::kNum &&
A.val[A.num_length / int2048::kNum - 1] == 0)
A.num_length -= int2048::kNum;
A.ClaimMem(A.num_length);
}
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
int new_length = 0;
for (int i = 0; i < A.num_length; i++)
if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0)
new_length = i + 1;
A.num_length = new_length;
if (A.num_length == 0) A.num_length = 1;
A.ClaimMem(A.num_length);
}
// 减去一个大整数
@ -488,7 +506,7 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
int blocks_of_A = ((A.num_length + int2048::kNum - 1) / int2048::kNum);
int blocks_of_B = ((pB->num_length + int2048::kNum - 1) / int2048::kNum);
int max_blocks = blocks_of_A + blocks_of_B;
int NTT_blocks = 1;
int NTT_blocks = 2;
while (NTT_blocks < (max_blocks << 1)) NTT_blocks <<= 1;
__int128_t *pDA = new __int128_t[NTT_blocks]();
__int128_t *pDB = new __int128_t[NTT_blocks]();
@ -503,6 +521,8 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlockBase;
}
} else {
assert(A.num_length % int2048::kNum == 0);
assert(pB->num_length % int2048::kNum == 0);
pDA[0] = A.val[0];
for (int i = 1; i < blocks_of_A; i++) {
pDA[i << 1] = A.val[i] % int2048::kNTTBlockBase;
@ -548,16 +568,23 @@ inline void UnsignedMultiply(int2048 &A, const int2048 *pB,
}
}
A.num_length = NTT_blocks * 4;
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
while (A.val[(A.num_length - 1) / int2048::kNum] /
kPow10[(A.num_length - 1) % int2048::kNum] ==
0) {
A.num_length--;
if (A.num_length == 0) {
A.num_length = 1;
break;
if (!inverse) {
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
while (A.val[(A.num_length - 1) / int2048::kNum] /
kPow10[(A.num_length - 1) % int2048::kNum] ==
0) {
A.num_length--;
if (A.num_length == 0) {
A.num_length = 1;
break;
}
}
} else {
while (A.num_length > int2048::kNum &&
A.val[A.num_length / int2048::kNum - 1] == 0)
A.num_length -= int2048::kNum;
A.ClaimMem(A.num_length);
}
delete[] pDA;
delete[] pDB;
@ -596,8 +623,11 @@ int2048 operator*(int2048 A, const int2048 &B) {
return std::move(A);
}
void int2048::ProcessHalfBlock() {
this->ClaimMem(this->num_length + int2048::kNTTBlockBase);
int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum;
assert(this->num_length % int2048::kNum == 0);
this->ClaimMem(this->num_length + int2048::kNum);
this->num_length = this->num_length + int2048::kNum;
assert(this->num_length % int2048::kNum == 0);
int blocks_num = this->num_length / int2048::kNum;
for (int i = blocks_num - 1; i >= 1; i--) {
val[i] /= int2048::kNTTBlockBase;
val[i] += (val[i - 1] % int2048::kNTTBlockBase) * int2048::kNTTBlockBase;
@ -605,13 +635,16 @@ void int2048::ProcessHalfBlock() {
val[0] /= int2048::kNTTBlockBase;
}
void int2048::RestoreHalfBlock() {
int blocks_num = (this->num_length + int2048::kNum - 1) / int2048::kNum;
assert(this->num_length % int2048::kNum == 0);
int blocks_num = this->num_length / int2048::kNum;
for (int i = 0; i < blocks_num - 1; i++) {
val[i] *= int2048::kNTTBlockBase;
val[i] %= int2048::kStoreBase;
val[i] += val[i + 1] / int2048::kNTTBlockBase;
}
(val[blocks_num - 1] *= int2048::kNTTBlockBase) %= int2048::kStoreBase;
while (this->num_length > 0 && val[this->num_length / int2048::kNum - 1] == 0)
this->num_length -= int2048::kNum;
}
inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
int L1 = A.num_length, L2 = pB->num_length;
@ -638,12 +671,15 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
int pow_B = (L2 + int2048::kNum - 1) / int2048::kNum - 1;
// pow_B+1 is the number of blocks (with number) of B'
int2048 inverse_B(*pB);
inverse_B.num_length = (inverse_B.num_length + int2048::kNum - 1) /
int2048::kNum * int2048::kNum;
for (int i = 0; (i << 1) < (pow_B + 1); i++)
std::swap(inverse_B.val[i], inverse_B.val[pow_B - i]);
int2048 x(
int2048::kStoreBase *
(long long)std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1)));
assert(x.val[1] == std::max(1, int2048::kStoreBase / (inverse_B.val[0] + 1)));
x.num_length = 2 * int2048::kNum;
int *store[2];
store[0] = new int[pow_A + 5]();
store[1] = new int[pow_A + 5]();
@ -659,21 +695,18 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
}
while (true) {
int2048 inverse_two(2), tmp_x(x);
inverse_two.num_length = int2048::kNum;
int tmp_x_error = 0;
if (tmp_x.val[0] >= int2048::kNTTBlockBase) {
tmp_x_error = 1;
tmp_x.ProcessHalfBlock();
}
assert(tmp_x.num_length % int2048::kNum == 0);
assert(inverse_B.num_length % int2048::kNum == 0);
UnsignedMultiply(tmp_x, &inverse_B, true);
tmp_x.num_length =
((tmp_x.num_length + int2048::kNum - 1) / int2048::kNum) *
int2048::kNum;
for (int i = 0; i < tmp_x_error + inverseB_error; i++)
tmp_x.RestoreHalfBlock();
UnsignedMinus(inverse_two, &tmp_x, true);
inverse_two.num_length =
((inverse_two.num_length + int2048::kNum - 1) / int2048::kNum) *
int2048::kNum;
int inverse_two_error = 0, x_error = 0;
if (inverse_two.val[0] >= int2048::kNTTBlockBase) {
inverse_two_error = 1;
@ -684,8 +717,6 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
x.ProcessHalfBlock();
}
UnsignedMultiply(x, &inverse_two, true);
x.num_length =
((x.num_length + int2048::kNum - 1) / int2048::kNum) * int2048::kNum;
for (int i = 0; i < x_error + inverse_two_error; i++) x.RestoreHalfBlock();
/**
* now x is the next x, store[tot] stores last x, store[tot^1] stores the x
@ -694,6 +725,7 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) {
int blocks_of_x = (x.num_length + int2048::kNum - 1) / int2048::kNum;
if (blocks_of_x > pow_A + 3) {
x.ClaimMem((pow_A + 3) * int2048::kNum);
x.num_length = (pow_A + 3) * int2048::kNum;
blocks_of_x = pow_A + 3;
}
bool pre_same = true, pre_pre_same = true;