Blog of RuSun

\begin {array}{c} \mathfrak {One Problem Is Difficult} \\\\ \mathfrak {Because You Don't Know} \\\\ \mathfrak {Why It Is Difficult} \end {array}

分治NTT模板

用于求解 $f$ 的每一项都依赖于自身前面各项、因而无法直接用一次卷积求出的问题(即“在线卷积”)。

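给定 $g _ 1, g _ 2, \dots, g _ {n - 1}$ ,已知 $f _ 0 = 1$ ,且对 $i \ge 1$ 有 $f _ i = \displaystyle \sum _ {j = 1} ^ i f _ {i - j} g _ j$ ,求 $f _ 0, f _ 1, \dots, f _ {n - 1}$ (对 $998244353$ 取模)。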

考虑 CDQ 分治:先递归求出 $[l, mid]$ 内的 $f$ ,再用一次 NTT 卷积把它们对 $[mid + 1, r]$ 的贡献累加上去,最后递归处理右半部分。每层卷积的总长度为 $O(n)$ ,共 $O(\log n)$ 层,复杂度 $O(n \log ^ 2 n)$ 。

查看代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
/**
 * Reads the next (optionally negative) decimal integer from stdin into x.
 *
 * Every non-digit character before the number is skipped. A '-' seen while
 * skipping marks the number as negative. Digits are then accumulated until
 * the first non-digit character, which is consumed and discarded.
 *
 * If end-of-file is reached before any digit, x is set to 0 and the function
 * returns. Previously the character was stored in a `char`, which cannot
 * represent EOF, so the skip loop spun forever on empty or truncated input.
 */
template <class Type>
void read (Type &x)
{
	int c;  // int, not char: getchar() signals end of input with EOF
	bool flag = false;
	while ((c = getchar()) < '0' || c > '9')
	{
		if (c == EOF)
		{
			x = 0;
			return;
		}
		if (c == '-')
			flag = true;
	}
	x = c - '0';
	while ((c = getchar()) >= '0' && c <= '9')
		x = (x << 1) + (x << 3) + c - '0';  // x * 10 + digit
	if (flag)
		x = ~x + 1;  // two's-complement negation, i.e. x = -x
}
// Capacity of the global work arrays (500000 plus a little slack).
const int N = 500010;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, with primitive root 3.
const int mod = 998244353;
// Bit-reversal permutation filled by PolyRev; rev[0] is never written and stays 0.
int rev[N];
/// Computes b^k modulo `mod` by square-and-multiply over the bits of k.
/// k is expected to be non-negative. Each step keeps a C++ remainder,
/// so the result is negative only when b is.
int binpow (int b, int k)
{
	int res = 1;
	for (; k != 0; k >>= 1)
	{
		if (k & 1)
			res = (LL) res * b % mod;
		b = (LL) b * b % mod;
	}
	return res;
}
/// Fills rev[1 .. 2^bit) so that rev[i] is i with its lowest `bit` bits
/// reversed. rev[0] is not written; it is a zero-initialized global.
void PolyRev (int bit)
{
	const int top = bit - 1;  // position of the highest bit of an index
	for (int i = 1, lim = 1 << bit; i < lim; ++i)
	{
		const int shifted = rev[i >> 1] >> 1;  // reversal of i without its lowest bit
		const int lowBit = (i & 1) << top;     // lowest bit of i moved to the top
		rev[i] = shifted | lowBit;
	}
}
/**
 * In-place iterative NTT of x[0 .. 2^bit) modulo `mod`, using 3 as primitive root.
 *
 * op == -1 selects the inverse transform: inverse roots, then scaling by
 * 1 / 2^bit. Any other value of op gives the forward transform.
 *
 * Butterflies use C++ remainders without normalising. If every input lies in
 * (-mod, mod), every output does too, and outputs may be negative.
 * Overwrites the global rev[] via PolyRev.
 */
void ntt (int *x, int bit, int op)
{
	PolyRev(bit);
	const int tot = 1 << bit;
	// Bit-reversal reordering; the comparison swaps each pair only once.
	for (int i = 0; i < tot; ++i)
		if (rev[i] < i)
			swap(x[rev[i]], x[i]);
	for (int half = 1; half < tot; half <<= 1)
	{
		const int len = half << 1;
		int step = binpow(3, (mod - 1) / len);  // primitive len-th root of unity
		if (op == -1)
			step = binpow(step, mod - 2);       // inverse transform uses inverse roots
		for (int base = 0; base < tot; base += len)
		{
			int w = 1;
			for (int j = 0; j < half; ++j)
			{
				int *lo = x + base + j;
				int *hi = lo + half;
				const int u = *lo;
				const int v = (LL) w * *hi % mod;
				*lo = (u + v) % mod;
				*hi = (u - v) % mod;
				w = (LL) w * step % mod;
			}
		}
	}
	if (op != -1)
		return;
	const int invTot = binpow(tot, mod - 2);
	for (int i = 0; i < tot; ++i)
		x[i] = (LL) x[i] * invTot % mod;
}
/**
 * CDQ divide and conquer for f_i = sum_{j=1}^{i} f_{i-j} g_j with f_0 = 1.
 *
 * Completes f[l..r] in three steps:
 *   1. recursively finish f[l..mid];
 *   2. add the contribution of f[l..mid] to every f[i], i in [mid+1, r],
 *      using one NTT convolution;
 *   3. recurse into [mid+1, r].
 *
 * Reads g[1 .. r-l]. Results are kept in (-mod, mod) and may be negative;
 * the caller normalises them. Assumes f starts zero-filled, since the leaf
 * only adds 1 at index 0 (main passes a static, zero-initialized array).
 * An empty range (l > r, e.g. len == 0) is a no-op. Previously it recursed
 * forever.
 */
void solve (int *f, int *g, int l, int r)
{
	if (l > r)
		return;
	if (l == r)
	{
		if (l == 0)
			f[l] += 1;  // base case f_0 = 1
		return;
	}
	int mid = l + r >> 1;
	solve(f, g, l, mid);
	// Scratch buffers are shared by every recursion level. This is safe
	// because they are fully consumed before the recursive call below.
	static int c[N], d[N];
	int bit = 0, lena = mid - l + 1, lenb = r - l;
	// Length >= lena + lenb leaves room for the whole linear product (no wrap-around).
	while (1 << bit < lena + lenb)
		bit++;
	for (int i = 0; i < 1 << bit; i++)
		c[i] = i < lena ? f[i + l] : 0;  // c[k] = f[l + k]
	for (int i = 0; i < 1 << bit; i++)
		d[i] = i < lenb ? g[i + 1] : 0;  // d[k] = g[k + 1], i.e. g_1 .. g_{r-l}
	ntt(c, bit, 1), ntt(d, bit, 1);
	for (int i = 0; i < 1 << bit; i++)
		c[i] = (LL) c[i] * d[i] % mod;
	ntt(c, bit, -1);
	// (c*d)[t] = sum_k f[l+k] * g[t-k+1]. For target i and source j = l+k the gap
	// i - j equals t - k + 1, so the contribution to f[i] sits at t = i - l - 1.
	for (int i = mid + 1; i <= r; i++)
		(f[i] += c[i - l - 1]) %= mod;
	solve(f, g, mid + 1, r);
}
int main ()
{
static int len, A[N], B[N];
read(len);
for (int i = 1; i < len; i++)
read(A[i]);
solve(B, A, 0, len - 1);
for (int i = 0; i < len; i++)
printf("%d ", (B[i] + mod) % mod);
return 0;
}

另有生成函数做法。约定 $g _ 0 = 0$ ,并设:


$$
F (x) = \sum _ i f _ i x ^ i
$$

$$
G(x) = \sum _ i g _ i x ^ i
$$


$$
F(x) G(x) = \sum _ i x ^ i \sum _ {j + k = i} f _ j g _ k = \sum _ {i \ge 1} f _ i x ^ i = F(x) - f _ 0 \quad (g _ 0 = 0)
$$

$$
F(x) G(x) \equiv F(x) - 1 \pmod {x ^ n}
$$

$$
F(x) (1 - G(x)) \equiv 1 \pmod {x ^ n} \quad \Longrightarrow \quad F(x) \equiv (1 - G(x)) ^ {-1} \pmod {x ^ n}
$$

查看代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
/**
 * Reads the next (optionally negative) decimal integer from stdin into x.
 *
 * Every non-digit character before the number is skipped. A '-' seen while
 * skipping marks the number as negative. Digits are then accumulated until
 * the first non-digit character, which is consumed and discarded.
 *
 * If end-of-file is reached before any digit, x is set to 0 and the function
 * returns. Previously the character was stored in a `char`, which cannot
 * represent EOF, so the skip loop spun forever on empty or truncated input.
 */
template <class Type>
void read(Type &x)
{
	int c;  // int, not char: getchar() signals end of input with EOF
	bool flag = false;
	while ((c = getchar()) < '0' || c > '9')
	{
		if (c == EOF)
		{
			x = 0;
			return;
		}
		if (c == '-')
			flag = true;
	}
	x = c - '0';
	while ((c = getchar()) >= '0' && c <= '9')
		x = (x << 3) + (x << 1) + c - '0';  // x * 10 + digit
	if (flag)
		x = ~x + 1;  // two's-complement negation, i.e. x = -x
}
/// Writes integer x to stdout in decimal, preceded by '-' if x is negative.
/// No separator or newline is written.
template <class Type>
void write(Type x)
{
	if (x < 0)
	{
		putchar('-');
		x = ~x + 1;  // two's-complement negation, i.e. x = -x
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[48];
	int count = 0;
	do
	{
		digits[count++] = static_cast<char>(x % 10 + '0');
		x /= 10;
	} while (x > 0);
	while (count > 0)
		putchar(digits[--count]);
}
// Capacity of the global work arrays (500000 plus a little slack).
const int N = 500010;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, with primitive root 3.
const int mod = 998244353;
// Bit-reversal permutation filled by PolyRev; rev[0] is never written and stays 0.
int rev[N];
int binpow(int b, int k)
{
int res = 1;
while (k)
{
if (k & 1)
res = (LL)res * b % mod;
b = (LL)b * b % mod;
k >>= 1;
}
return res;
}
/// Builds the bit-reversal table for indices of width `bit` into rev[1 .. 2^bit).
/// rev[0] is untouched; it is a zero-initialized global.
void PolyRev(int bit)
{
	const int size = 1 << bit;
	for (int i = 1; i < size; ++i)
	{
		int r = rev[i >> 1] >> 1;  // reversal of i with its lowest bit dropped
		if (i & 1)
			r |= 1 << (bit - 1);   // lowest bit of i becomes the highest
		rev[i] = r;
	}
}
/**
 * In-place iterative NTT of x[0 .. 2^bit) modulo `mod`, using 3 as primitive root.
 *
 * op == -1 performs the inverse transform: inverse roots, then multiplication
 * by the inverse of 2^bit. Any other op performs the forward transform.
 *
 * Butterflies use C++ remainders without normalising. If every input lies in
 * (-mod, mod), every output does too, and outputs may be negative.
 * Overwrites the global rev[] via PolyRev.
 */
void ntt(int *x, int bit, int op)
{
	const bool inverse = (op == -1);
	PolyRev(bit);
	const int n = 1 << bit;
	for (int i = 0; i < n; ++i)
		if (i < rev[i])
			swap(x[i], x[rev[i]]);
	for (int len = 2; len <= n; len <<= 1)
	{
		const int half = len >> 1;
		const int root = binpow(3, (mod - 1) / len);  // primitive len-th root of unity
		const int wn = inverse ? binpow(root, mod - 2) : root;
		for (int s = 0; s < n; s += len)
		{
			int w = 1;
			for (int k = s; k < s + half; ++k)
			{
				const int a = x[k];
				const int b = (LL)w * x[k + half] % mod;
				x[k] = (a + b) % mod;
				x[k + half] = (a - b) % mod;
				w = (LL)w * wn % mod;
			}
		}
	}
	if (!inverse)
		return;
	const int invN = binpow(n, mod - 2);
	for (int i = 0; i < n; ++i)
		x[i] = (LL)x[i] * invN % mod;
}
/**
 * Computes g = x^{-1} mod t^len (power series in t) by Newton iteration.
 *
 * From an inverse g_m valid mod t^m with m = ceil(len / 2), it applies
 *   g = g_m * (2 - x * g_m)  mod t^len.
 *
 * Uses an NTT length 2^bit >= 2 * len, so the degree < 2 * len product does
 * not wrap around. Coefficients are left in (-mod, mod) and may be negative.
 *
 * Assumes: x[0] is invertible modulo `mod`, and g is zero-filled on entry.
 * The forward NTT of g reads g[0 .. 2^bit). main passes a static,
 * zero-initialized array.
 * len <= 0 is a no-op. Previously ceil(0 / 2) == 0 recursed forever.
 */
void PolyInv(int *x, int *g, int len)
{
	if (len <= 0)
		return;
	if (len == 1)
	{
		g[0] = binpow(x[0], mod - 2);  // inverse of the constant term (Fermat)
		return;
	}
	PolyInv(x, g, (len + 1) >> 1);
	int bit = 0;
	while ((1 << bit) < (len << 1))
		bit++;
	int tot = 1 << bit;
	// Safe to share across recursion levels: filled only after the recursive call returns.
	static int c[N];
	for (int i = 0; i < tot; i++)
		c[i] = i < len ? x[i] : 0;  // x truncated to t^len
	ntt(g, bit, 1), ntt(c, bit, 1);
	for (int i = 0; i < tot; i++)
		g[i] = (2 - (LL)g[i] * c[i]) % mod * g[i] % mod;  // pointwise g * (2 - x g)
	ntt(g, bit, -1);
	// Truncate to t^len so higher entries are zero for the caller's next, larger step.
	for (int i = len; i < tot; i++)
		g[i] = 0;
}
// Reads n and g_1 .. g_{n-1}, then prints F = (1 - G)^{-1} mod x^n, i.e. f_0 .. f_{n-1},
// modulo 998244353, each followed by a space.
int main()
{
	// static: zero-initialized, which PolyInv relies on for its output array f.
	static int n, f[N], g[N];
	read(n);
	// g becomes 1 - G(x): constant term 1, then -g_i for i >= 1.
	g[0] = 1;
	for (int i = 1; i < n; ++i)
	{
		read(g[i]);
		g[i] = -g[i];
	}
	PolyInv(g, f, n);
	for (int i = 0; i < n; ++i)
	{
		write((f[i] + mod) % mod);  // PolyInv may leave negative residues
		putchar(' ');
	}
	return 0;
}