On this page
article
NTT
Sobre
Precisa do mint (primitivas de aritmética modular)
Complexidade: O(n log n)
Link original: ntt.cpp
Código
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, so power-of-two transform
// sizes up to 2^23 divide MOD - 1.
constexpr int MOD = 998244353;
// Integers modulo MOD; mod_int comes from the "mint" modular-arithmetic
// primitives this snippet depends on (not shown here).
using mint = mod_int<MOD>;
// In-place number-theoretic transform (NTT) of `a` over the integers mod MOD.
// Uses out-of-place radix-2 passes (Stockham-style) that ping-pong between `a`
// and a scratch buffer, so no separate bit-reversal permutation is needed.
//
// a   : coefficient vector. Its size must be a power of two that divides
//       MOD - 1 (at most 2^23 for MOD = 998244353). Sizes 0 and 1 are no-ops.
// rev : false -> forward transform. true -> inverse transform, which also
//       multiplies every entry by 1/n, so a forward transform followed by an
//       inverse one restores the original vector.
//
// Complexity: O(n log n) time, O(n) extra memory.
void ntt(vector<mint>& a, bool rev) {
	int n = a.size();
	// A transform of length 0 or 1 is the identity. Returning early also avoids
	// evaluating mint(1) / 0 when inverting an empty vector.
	if (n <= 1) return;
	assert(!(n & (n - 1)));      // n must be a power of two
	assert((MOD - 1) % n == 0);  // otherwise no primitive n-th root of unity exists mod MOD
	auto b = a;                  // scratch buffer; every pass overwrites all n entries

	// Smallest g >= 1 with g^((MOD-1)/2) != 1, i.e. a quadratic non-residue
	// (MOD / 2 == (MOD - 1) / 2 because MOD is odd). For prime MOD, its order
	// contains the full power of two dividing MOD - 1. Hence g^((MOD-1)/m) is a
	// primitive m-th root of unity for every power of two m dividing MOD - 1.
	mint g = 1;
	while ((g^(MOD / 2)) == 1) g += 1;
	if (rev) g = 1 / g;  // the inverse transform uses the inverse roots

	// Each pass combines the partial transforms of the previous pass into
	// transforms twice as long. It uses w, a primitive m-th root of unity with
	// m = n / step. Since m divides MOD - 1, MOD / m == (MOD - 1) / m.
	for (int step = n / 2; step; step /= 2) {
		mint w = g^(MOD / (n / step)), wn = 1;
		for (int i = 0; i < n / 2; i += step) {
			for (int j = 0; j < step; j++) {
				auto u = a[2 * i + j], v = wn * a[2 * i + j + step];
				b[i + j] = u + v;
				b[i + n / 2 + j] = u - v;
			}
			wn = wn * w;
		}
		swap(a, b);  // exchange buffers so `a` holds this pass's output
	}
	if (rev) {
		auto n1 = mint(1) / n;
		for (auto& x : a) x *= n1;
	}
}
vector<mint> convolution(const vector<mint>& a, const vector<mint>& b) {
vector<mint> l(a.begin(), a.end()), r(b.begin(), b.end());
int N = l.size()+r.size()-1, n = 1;
while (n <= N) n *= 2;
l.resize(n);
r.resize(n);
ntt(l, false);
ntt(r, false);
for (int i = 0; i < n; i++) l[i] *= r[i];
ntt(l, true);
l.resize(N);
return l;
}