upd: first version of *

This commit is contained in:
2023-10-30 16:56:01 +08:00
parent 116e675b29
commit 1e543b6a7b
2 changed files with 134 additions and 17 deletions

View File

@ -15,14 +15,39 @@ class int2048 {
* num_length is the length of the integer, (num_length+kNum-1)/kNum is the
* length of val with data. Note that position in val without data is 0.
*/
size_t buf_length = 0;
const static int kMod = 100000000, kNum = 8, kDefaultLength = 10;
const static int kMemAdditionScalar = 2, kMemDeleteScalar = 4;
/**
* the following data used by NTT is generated by this code:
#!/usr/bin/python3
from sympy import isprime,primitive_root
found=False
for i in range(0,20):
for j in range(2**i,(2**(i+1))):
V=j*(2**(57-i))+1
if isprime(V):
found=True
print(j,57-i)
print("number=",V)
print("root=",primitive_root(V))
exit(0)
* it outputs:
5 55
number= 180143985094819841
root= 6
*/
const static __int128_t kNTTMod = 180143985094819841ll;
const static __int128_t kNTTRoot = 6;
const static int kNTTBlockNum = 4;
const static int kNTTBlcokBase = 10000;
size_t buf_length = 0;
int *val = nullptr;
signed char flag = +1;
int num_length = 0;
void NTT(__int128_t *, int, int);
__int128_t QuickPow(__int128_t v, long long q);
void NTTTransform(__int128_t *, int, bool);
public:
int2048();

View File

@ -26,6 +26,9 @@
#include <cstdio>
#include <cstring>
#include <utility>
static_assert(sizeof(int) == 4, "sizeof(int) != 4");
static_assert(sizeof(long long) == 8, "sizeof(long long)!=8");
namespace sjtu {
// 构造函数
int2048::int2048() {
@ -155,7 +158,7 @@ void int2048::ClaimMem(size_t number_length) {
inline int UnsignedCmp(const int2048 &A, const int2048 &B) {
if (A.num_length != B.num_length) return A.num_length < B.num_length ? -1 : 1;
int number_of_blocks = (A.num_length + A.kNum - 1) / A.kNum;
int number_of_blocks = (A.num_length + int2048::kNum - 1) / int2048::kNum;
for (int i = number_of_blocks - 1; i >= 0; i--)
if (A.val[i] != B.val[i]) return A.val[i] < B.val[i] ? -1 : 1;
return 0;
@ -165,16 +168,20 @@ inline void UnsignedAdd(int2048 &A, const int2048 *const pB) {
if (&A == pB) throw "UnsignedAdd: A and B are the same object";
A.ClaimMem(std::max(A.num_length, pB->num_length) + 2);
for (int i = 0;
i < (std::max(A.num_length, pB->num_length) + A.kNum - 1) / A.kNum;
i < (std::max(A.num_length, pB->num_length) + int2048::kNum - 1) /
int2048::kNum;
i++) {
if (i < (pB->num_length + pB->kNum - 1) / pB->kNum) A.val[i] += pB->val[i];
A.val[i + 1] += A.val[i] / A.kMod;
A.val[i] %= A.kMod;
if (i < (pB->num_length + int2048::kNum - 1) / int2048::kNum)
A.val[i] += pB->val[i];
A.val[i + 1] += A.val[i] / int2048::kMod;
A.val[i] %= int2048::kMod;
}
A.num_length = std::max(A.num_length, pB->num_length);
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
if (A.val[A.num_length / A.kNum] / kPow10[A.num_length % A.kNum] > 0)
if (A.val[A.num_length / int2048::kNum] /
kPow10[A.num_length % int2048::kNum] >
0)
A.num_length++;
}
@ -222,10 +229,11 @@ int2048 add(int2048 A, const int2048 &B) {
inline void UnsignedMinus(int2048 &A, const int2048 *const pB) {
if (&A == pB) throw "UnsignedMinus: A and B are the same object";
for (int i = 0; i < (pB->num_length + A.kNum - 1) / A.kNum; i++) {
for (int i = 0; i < (pB->num_length + int2048::kNum - 1) / int2048::kNum;
i++) {
A.val[i] -= pB->val[i];
if (A.val[i] < 0) {
A.val[i] += A.kMod;
A.val[i] += int2048::kMod;
A.val[i + 1]--;
}
}
@ -233,7 +241,8 @@ inline void UnsignedMinus(int2048 &A, const int2048 *const pB) {
100000, 1000000, 10000000, 100000000};
int new_length = 0;
for (int i = 0; i < A.num_length; i++)
if (A.val[i / A.kNum] / kPow10[i % A.kNum] > 0) new_length = i + 1;
if (A.val[i / int2048::kNum] / kPow10[i % int2048::kNum] > 0)
new_length = i + 1;
A.num_length = new_length;
if (A.num_length == 0) A.num_length = 1;
A.ClaimMem(A.num_length);
@ -329,10 +338,92 @@ int2048 operator-(int2048 A, const int2048 &B) {
A.minus(B);
return std::move(A);
}
inline void UnsignedMultiply(int2048 &A, const int2048 *pB)
{
;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param v base; it is reduced modulo kNTTMod before use.
 * @param q exponent; any value <= 0 skips the loop and yields 1.
 * @return v raised to the power q, reduced modulo kNTTMod.
 */
__int128_t int2048::QuickPow(__int128_t v, long long q) {
  const __int128_t mod = int2048::kNTTMod;
  __int128_t result = 1;
  __int128_t base = v % mod;
  // Walk the exponent from its lowest bit upwards: multiply the current
  // square into the result whenever the bit is set, then square again.
  for (long long exponent = q; exponent > 0; exponent >>= 1) {
    if ((exponent & 1) != 0) result = result * base % mod;
    base = base * base % mod;
  }
  return result;
}
/**
 * In-place iterative number-theoretic transform modulo kNTTMod.
 *
 * @param a          array of NTT_blocks residues, overwritten with the result.
 * @param NTT_blocks transform length; the bit-reversal step below expects a
 *                   power of two (the caller visible in this file passes one).
 * @param inverse    when true, the inverted roots are used and every element
 *                   is finally multiplied by the modular inverse of
 *                   NTT_blocks.
 */
void int2048::NTTTransform(__int128_t *a, int NTT_blocks,
                           bool inverse = false) {
  const __int128_t mod = int2048::kNTTMod;

  // Stage 1: bit-reversal permutation. `mirrored` follows the bit-reversed
  // counterpart of `idx`; each pair is exchanged exactly once.
  int mirrored = 0;
  for (int idx = 1; idx < NTT_blocks; ++idx) {
    int mask = NTT_blocks >> 1;
    for (; mirrored >= mask; mask >>= 1) mirrored -= mask;
    mirrored += mask;
    if (idx < mirrored) std::swap(a[idx], a[mirrored]);
  }

  // Stage 2: butterfly passes over groups of 2, 4, ..., NTT_blocks elements.
  for (int span = 2; span <= NTT_blocks; span <<= 1) {
    const int half = span / 2;
    // Root of unity of order `span`; inverted via Fermat's little theorem
    // for the inverse transform.
    __int128_t step = QuickPow(int2048::kNTTRoot, (mod - 1) / span);
    if (inverse) step = QuickPow(step, mod - 2);
    for (int start = 0; start < NTT_blocks; start += span) {
      __int128_t twiddle = 1;
      for (int k = 0; k < half; ++k) {
        __int128_t &low = a[start + k];
        __int128_t &high = a[start + k + half];
        const __int128_t even = low;
        const __int128_t odd = high * twiddle % mod;
        low = (even + odd) % mod;
        // Adding mod before reducing keeps the difference non-negative.
        high = (even - odd + mod) % mod;
        twiddle = twiddle * step % mod;
      }
    }
  }

  if (!inverse) return;

  // Stage 3 (inverse only): scale by NTT_blocks^{-1} mod kNTTMod.
  const __int128_t scale = QuickPow(NTT_blocks, mod - 2);
  for (int idx = 0; idx < NTT_blocks; ++idx) a[idx] = a[idx] * scale % mod;
}
/**
 * Multiplies the digits stored in A by the digits stored in *pB and writes
 * the product back into A (A.val and A.num_length). A.flag is not assigned
 * by this function.
 *
 * Each kNum-digit element of val is split into two base-kNTTBlcokBase limbs,
 * the limb sequences are convolved with the NTT modulo kNTTMod, carries are
 * propagated, and limb pairs are packed back into val.
 *
 * Scratch memory is owned by a local RAII helper, so it is released on every
 * exit path, including the throws below.
 *
 * @param A   left operand, overwritten with the product.
 * @param pB  right operand; must not point to A.
 * @throws const char* if pB aliases A, if the top limb still overflows after
 *         carry propagation, or if no non-zero digit remains in the result
 *         (num_length would become 0).
 */
inline void UnsignedMultiply(int2048 &A, const int2048 *pB) {
  if (&A == pB) throw "UnsignedMultiply: A and B are the same object";

  // Minimal owner for a zero-initialised __int128_t array.
  struct Scratch {
    __int128_t *data;
    explicit Scratch(size_t count) : data(new __int128_t[count]()) {}
    ~Scratch() { delete[] data; }
    Scratch(const Scratch &) = delete;
    Scratch &operator=(const Scratch &) = delete;
  };

  const int blocks_of_A = (A.num_length + int2048::kNum - 1) / int2048::kNum;
  const int blocks_of_B =
      (pB->num_length + int2048::kNum - 1) / int2048::kNum;
  const int max_blocks = blocks_of_A + blocks_of_B;
  // Every val element becomes two limbs, so the transform length is the
  // smallest power of two that is >= 2 * max_blocks.
  int NTT_blocks = 1;
  while (NTT_blocks < (max_blocks << 1)) NTT_blocks <<= 1;

  Scratch lhs(static_cast<size_t>(NTT_blocks));
  Scratch rhs(static_cast<size_t>(NTT_blocks));
  __int128_t *const pDA = lhs.data;
  __int128_t *const pDB = rhs.data;

  for (int i = 0; i < blocks_of_A; i++) {
    pDA[i << 1] = A.val[i] % int2048::kNTTBlcokBase;
    pDA[(i << 1) | 1] = A.val[i] / int2048::kNTTBlcokBase;
  }
  for (int i = 0; i < blocks_of_B; i++) {
    pDB[i << 1] = pB->val[i] % int2048::kNTTBlcokBase;
    pDB[(i << 1) | 1] = pB->val[i] / int2048::kNTTBlcokBase;
  }

  // Forward transforms, pointwise product (kept in pDA), inverse transform.
  A.NTTTransform(pDA, NTT_blocks, false);
  A.NTTTransform(pDB, NTT_blocks, false);
  for (int i = 0; i < NTT_blocks; i++)
    pDA[i] = (pDA[i] * pDB[i]) % int2048::kNTTMod;
  A.NTTTransform(pDA, NTT_blocks, true);

  // Carry propagation so that every limb is below kNTTBlcokBase.
  for (int i = 0; i < NTT_blocks - 1; i++) {
    pDA[i + 1] += pDA[i] / int2048::kNTTBlcokBase;
    pDA[i] %= int2048::kNTTBlcokBase;
  }
  if (pDA[NTT_blocks - 1] >= int2048::kNTTBlcokBase)
    throw "UnsignedMultiply: NTT result overflow";

  // Each limb carries 4 decimal digits, hence NTT_blocks * 4 digits in total.
  A.ClaimMem(NTT_blocks * 4);
  std::memset(A.val, 0, A.buf_length * sizeof(int));
  for (int i = 0; i < NTT_blocks / 2; i++) {
    // Both limbs are below kNTTBlcokBase here, so the packed value fits int.
    A.val[i] = static_cast<int>(
        pDA[(i << 1) | 1] * int2048::kNTTBlcokBase + pDA[i << 1]);
  }

  // Start from the maximum possible digit count and strip leading zeros.
  A.num_length = NTT_blocks * 4;
  const static int kPow10[9] = {1,      10,      100,      1000,     10000,
                                100000, 1000000, 10000000, 100000000};
  while (A.val[(A.num_length - 1) / int2048::kNum] /
             kPow10[(A.num_length - 1) % int2048::kNum] ==
         0) {
    A.num_length--;
    if (A.num_length == 0) throw "UnsignedMultiply: num_length==0";
  }
}
int2048 &int2048::Multiply(const int2048 &B) {
@ -345,7 +436,7 @@ int2048 &int2048::Multiply(const int2048 &B) {
return *this;
}
this->flag = this->flag * pB->flag;
UnsignedMultiply(*this,pB);
UnsignedMultiply(*this, pB);
return *this;
}
@ -395,7 +486,8 @@ std::ostream &operator<<(std::ostream &stream, const int2048 &v) {
const static int kPow10[9] = {1, 10, 100, 1000, 10000,
100000, 1000000, 10000000, 100000000};
for (int i = v.num_length - 1; i >= 0; i--)
stream << char('0' + v.val[i / v.kNum] / kPow10[i % v.kNum] % 10);
stream << char('0' +
v.val[i / int2048::kNum] / kPow10[i % int2048::kNum] % 10);
return stream;
}