diff --git a/include/int2048.h b/include/int2048.h index 94164b9..509747f 100644 --- a/include/int2048.h +++ b/include/int2048.h @@ -48,7 +48,8 @@ root= 6 __int128_t QuickPow(__int128_t v, long long q); void NTTTransform(__int128_t *, int, bool); - + friend int2048 GetInv(const int2048 &,int); + public: int2048(); int2048(long long); @@ -114,7 +115,6 @@ root= 6 friend bool operator<=(const int2048 &, const int2048 &); friend bool operator>=(const int2048 &, const int2048 &); }; -int2048 GetInv(const int2048 &,int); } // namespace sjtu #endif \ No newline at end of file diff --git a/src/int2048.cpp b/src/int2048.cpp index 0fb459f..8c5fdd4 100644 --- a/src/int2048.cpp +++ b/src/int2048.cpp @@ -562,7 +562,65 @@ void int2048::UnsignedMultiplyByInt(int v) { this->num_length--; } -int2048 GetInv(const int2048 &, int) { ; } +/** + * @brief Estimate the inverse of B, which is the result of [(base)^2n/B] + */ +int2048 GetInv(const int2048 &B, int n) { + const static int kPow10[9] = {1, 10, 100, 1000, 10000, + 100000, 1000000, 10000000, 100000000}; + int total_blocks = (B.num_length + int2048::kNum - 1) / int2048::kNum; + if (n <= 2) { + long long b = B.val[total_blocks - 1] * (long long)int2048::kStoreBase + + B.val[total_blocks - 2]; + int2048 res; + res.ClaimMem((2 * n) * int2048::kNum + 1); + res.val[2 * n] = 1; + __uint128_t c = 0; + for (int i = 2 * n; i >= 0; i--) { + c = c * int2048::kStoreBase + res.val[i]; + res.val[i] = c / b; + c %= b; + assert(res.val[i] < int2048::kStoreBase); + } + res.num_length = (2 * n) * int2048::kNum + 1; + while (res.num_length > 1 && + res.val[(res.num_length - 1) / int2048::kNum] / + kPow10[(res.num_length - 1) % int2048::kNum] == + 0) + res.num_length--; + return std::move(res); + } + int k = (n + 2) >> 1; + int2048 sub_soluton = std::move(GetInv(B, k)); + int2048 sub_soluton_copy_1(sub_soluton); + int2048 sub_soluton_copy_2(sub_soluton); + sub_soluton_copy_1.UnsignedMultiplyByInt(2); + sub_soluton_copy_1.RightMoveBy((n - k) * 
int2048::kNum); + int2048 current_B; // current_B is the highest n blocks of B + current_B.ClaimMem(n * int2048::kNum); + for (int i = n - 1; i >= 0; i--) + current_B.val[i] = B.val[i + total_blocks - n]; + UnsignedMultiply(sub_soluton_copy_2, &current_B); + UnsignedMultiply(sub_soluton_copy_2, &sub_soluton); + sub_soluton_copy_2.RightMoveBy(2 * k * int2048::kNum); + UnsignedMinus(sub_soluton_copy_1, &sub_soluton_copy_2); + int2048 res = sub_soluton_copy_1; + int2048 remain; + remain.ClaimMem((2 * n) * int2048::kNum + 1); + remain.val[2 * n] = 1; + remain.num_length = (2 * n) * int2048::kNum + 1; + UnsignedMultiply(sub_soluton_copy_1, &current_B); + UnsignedMinus(remain, &sub_soluton_copy_1); + for (int i = 64; i > 0; i >>= 1) { + int2048 tmp_B(current_B); + tmp_B.UnsignedMultiplyByInt(i); + if(UnsignedCmp(remain, tmp_B) >= 0) { + res += i; + UnsignedMinus(remain, &tmp_B); + } + } + return std::move(res); +} inline void UnsignedDivide(int2048 &A, const int2048 *pB) { int2048 B(*pB); int L1 = (A.num_length + int2048::kNum - 1) / int2048::kNum; @@ -579,6 +637,7 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { c = c * int2048::kStoreBase + A.val[i]; A.val[i] = c / b; c %= b; + assert(A.val[i] < int2048::kStoreBase); } A.num_length = L1 * int2048::kNum; const static int kPow10[9] = {1, 10, 100, 1000, 10000, @@ -612,7 +671,13 @@ inline void UnsignedDivide(int2048 &A, const int2048 *pB) { UnsignedMinus(remain, &tmp_B); for (int i = 8; i > 0; i >>= 1) { tmp_B = B; + tmp_B.UnsignedMultiplyByInt(i); + if (UnsignedCmp(remain, tmp_B) >= 0) { + res_hat += i; + UnsignedMinus(remain, &tmp_B); + } } + A = std::move(res_hat); } int2048 &int2048::Divide(const int2048 &B) { if (this == &B) { diff --git a/tester/cases/3.py b/tester/cases/3.py index a1f1ffb..c93c5fd 100755 --- a/tester/cases/3.py +++ b/tester/cases/3.py @@ -36,7 +36,7 @@ opt_python=[] if True: for i in range(0,10): - val=randint(-10**15,10**15) + val=randint(-10**32,10**32) if i==0: val=randint(-10**100,10**100) 
opt_cpp.append("a_"+str(i)+"=int2048(\""+str(val)+"\");")