@@ -14,95 +14,169 @@ namespace atcoder {
14
14
15
15
namespace internal {
16
16
17
+ template <class mint ,
18
+ int g = internal::primitive_root<mint::mod()>,
19
+ internal::is_static_modint_t <mint>* = nullptr >
20
+ struct fft_info {
21
+ static constexpr int rank2 = bsf_constexpr(mint::mod() - 1 );
22
+ std::array<mint, rank2 + 1 > root; // root[i]^(2^i) == 1
23
+ std::array<mint, rank2 + 1 > iroot; // root[i] * iroot[i] == 1
24
+
25
+ std::array<mint, std::max(0 , rank2 - 2 + 1 )> rate2;
26
+ std::array<mint, std::max(0 , rank2 - 2 + 1 )> irate2;
27
+
28
+ std::array<mint, std::max(0 , rank2 - 3 + 1 )> rate3;
29
+ std::array<mint, std::max(0 , rank2 - 3 + 1 )> irate3;
30
+
31
+ fft_info () {
32
+ root[rank2] = mint (g).pow ((mint::mod () - 1 ) >> rank2);
33
+ iroot[rank2] = root[rank2].inv ();
34
+ for (int i = rank2 - 1 ; i >= 0 ; i--) {
35
+ root[i] = root[i + 1 ] * root[i + 1 ];
36
+ iroot[i] = iroot[i + 1 ] * iroot[i + 1 ];
37
+ }
38
+
39
+ {
40
+ mint prod = 1 , iprod = 1 ;
41
+ for (int i = 0 ; i <= rank2 - 2 ; i++) {
42
+ rate2[i] = root[i + 2 ] * prod;
43
+ irate2[i] = iroot[i + 2 ] * iprod;
44
+ prod *= iroot[i + 2 ];
45
+ iprod *= root[i + 2 ];
46
+ }
47
+ }
48
+ {
49
+ mint prod = 1 , iprod = 1 ;
50
+ for (int i = 0 ; i <= rank2 - 3 ; i++) {
51
+ rate3[i] = root[i + 3 ] * prod;
52
+ irate3[i] = iroot[i + 3 ] * iprod;
53
+ prod *= iroot[i + 3 ];
54
+ iprod *= root[i + 3 ];
55
+ }
56
+ }
57
+ }
58
+ };
59
+
17
60
template <class mint , internal::is_static_modint_t <mint>* = nullptr >
18
61
void butterfly (std::vector<mint>& a) {
19
- static constexpr int g = internal::primitive_root<mint::mod ()>;
20
62
int n = int (a.size ());
21
63
int h = internal::ceil_pow2 (n);
22
64
23
- static bool first = true ;
24
- static mint sum_e[30 ]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
25
- if (first) {
26
- first = false ;
27
- mint es[30 ], ies[30 ]; // es[i]^(2^(2+i)) == 1
28
- int cnt2 = bsf (mint::mod () - 1 );
29
- mint e = mint (g).pow ((mint::mod () - 1 ) >> cnt2), ie = e.inv ();
30
- for (int i = cnt2; i >= 2 ; i--) {
31
- // e^(2^i) == 1
32
- es[i - 2 ] = e;
33
- ies[i - 2 ] = ie;
34
- e *= e;
35
- ie *= ie;
36
- }
37
- mint now = 1 ;
38
- for (int i = 0 ; i <= cnt2 - 2 ; i++) {
39
- sum_e[i] = es[i] * now;
40
- now *= ies[i];
41
- }
42
- }
43
- for (int ph = 1 ; ph <= h; ph++) {
44
- int w = 1 << (ph - 1 ), p = 1 << (h - ph);
45
- mint now = 1 ;
46
- for (int s = 0 ; s < w; s++) {
47
- int offset = s << (h - ph + 1 );
48
- for (int i = 0 ; i < p; i++) {
49
- auto l = a[i + offset];
50
- auto r = a[i + offset + p] * now;
51
- a[i + offset] = l + r;
52
- a[i + offset + p] = l - r;
65
+ static const fft_info<mint> info;
66
+
67
+ int len = 0 ; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
68
+ while (len < h) {
69
+ if (h - len == 1 ) {
70
+ int p = 1 << (h - len - 1 );
71
+ mint rot = 1 ;
72
+ for (int s = 0 ; s < (1 << len); s++) {
73
+ int offset = s << (h - len);
74
+ for (int i = 0 ; i < p; i++) {
75
+ auto l = a[i + offset];
76
+ auto r = a[i + offset + p] * rot;
77
+ a[i + offset] = l + r;
78
+ a[i + offset + p] = l - r;
79
+ }
80
+ if (s + 1 != (1 << len))
81
+ rot *= info.rate2 [bsf (~(unsigned int )(s))];
82
+ }
83
+ len++;
84
+ } else {
85
+ // 4-base
86
+ int p = 1 << (h - len - 2 );
87
+ mint rot = 1 , imag = info.root [2 ];
88
+ for (int s = 0 ; s < (1 << len); s++) {
89
+ mint rot2 = rot * rot;
90
+ mint rot3 = rot2 * rot;
91
+ int offset = s << (h - len);
92
+ for (int i = 0 ; i < p; i++) {
93
+ auto mod2 = 1ULL * mint::mod () * mint::mod ();
94
+ auto a0 = 1ULL * a[i + offset].val ();
95
+ auto a1 = 1ULL * a[i + offset + p].val () * rot.val ();
96
+ auto a2 = 1ULL * a[i + offset + 2 * p].val () * rot2.val ();
97
+ auto a3 = 1ULL * a[i + offset + 3 * p].val () * rot3.val ();
98
+ auto a1na3imag =
99
+ 1ULL * mint (a1 + mod2 - a3).val () * imag.val ();
100
+ auto na2 = mod2 - a2;
101
+ a[i + offset] = a0 + a2 + a1 + a3;
102
+ a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
103
+ a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
104
+ a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
105
+ }
106
+ if (s + 1 != (1 << len))
107
+ rot *= info.rate3 [bsf (~(unsigned int )(s))];
53
108
}
54
- now *= sum_e[ bsf (~( unsigned int )(s))] ;
109
+ len += 2 ;
55
110
}
56
111
}
57
112
}
58
113
59
114
template <class mint , internal::is_static_modint_t <mint>* = nullptr >
60
115
void butterfly_inv (std::vector<mint>& a) {
61
- static constexpr int g = internal::primitive_root<mint::mod ()>;
62
116
int n = int (a.size ());
63
117
int h = internal::ceil_pow2 (n);
64
118
65
- static bool first = true ;
66
- static mint sum_ie[30 ]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
67
- if (first) {
68
- first = false ;
69
- mint es[30 ], ies[30 ]; // es[i]^(2^(2+i)) == 1
70
- int cnt2 = bsf (mint::mod () - 1 );
71
- mint e = mint (g).pow ((mint::mod () - 1 ) >> cnt2), ie = e.inv ();
72
- for (int i = cnt2; i >= 2 ; i--) {
73
- // e^(2^i) == 1
74
- es[i - 2 ] = e;
75
- ies[i - 2 ] = ie;
76
- e *= e;
77
- ie *= ie;
78
- }
79
- mint now = 1 ;
80
- for (int i = 0 ; i <= cnt2 - 2 ; i++) {
81
- sum_ie[i] = ies[i] * now;
82
- now *= es[i];
83
- }
84
- }
119
+ static const fft_info<mint> info;
120
+
121
+ int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
122
+ while (len) {
123
+ if (len == 1 ) {
124
+ int p = 1 << (h - len);
125
+ mint irot = 1 ;
126
+ for (int s = 0 ; s < (1 << (len - 1 )); s++) {
127
+ int offset = s << (h - len + 1 );
128
+ for (int i = 0 ; i < p; i++) {
129
+ auto l = a[i + offset];
130
+ auto r = a[i + offset + p];
131
+ a[i + offset] = l + r;
132
+ a[i + offset + p] =
133
+ (unsigned long long )(mint::mod () + l.val () - r.val ()) *
134
+ irot.val ();
135
+ ;
136
+ }
137
+ if (s + 1 != (1 << (len - 1 )))
138
+ irot *= info.irate2 [bsf (~(unsigned int )(s))];
139
+ }
140
+ len--;
141
+ } else {
142
+ // 4-base
143
+ int p = 1 << (h - len);
144
+ mint irot = 1 , iimag = info.iroot [2 ];
145
+ for (int s = 0 ; s < (1 << (len - 2 )); s++) {
146
+ mint irot2 = irot * irot;
147
+ mint irot3 = irot2 * irot;
148
+ int offset = s << (h - len + 2 );
149
+ for (int i = 0 ; i < p; i++) {
150
+ auto a0 = 1ULL * a[i + offset + 0 * p].val ();
151
+ auto a1 = 1ULL * a[i + offset + 1 * p].val ();
152
+ auto a2 = 1ULL * a[i + offset + 2 * p].val ();
153
+ auto a3 = 1ULL * a[i + offset + 3 * p].val ();
154
+
155
+ auto a2na3iimag =
156
+ 1ULL *
157
+ mint ((mint::mod () + a2 - a3) * iimag.val ()).val ();
85
158
86
- for ( int ph = h; ph >= 1 ; ph--) {
87
- int w = 1 << (ph - 1 ), p = 1 << (h - ph);
88
- mint inow = 1 ;
89
- for ( int s = 0 ; s < w; s++) {
90
- int offset = s << (h - ph + 1 );
91
- for ( int i = 0 ; i < p; i++) {
92
- auto l = a[i + offset];
93
- auto r = a[i + offset + p];
94
- a[i + offset] = l + r ;
95
- a[i + offset + p] =
96
- ( unsigned long long )( mint::mod () + l. val () - r. val ()) *
97
- inow. val () ;
159
+ a[i + offset] = a0 + a1 + a2 + a3;
160
+ a[i + offset + 1 * p] =
161
+ (a0 + ( mint::mod () - a1) + a2na3iimag) * irot. val () ;
162
+ a[i + offset + 2 * p] =
163
+ (a0 + a1 + ( mint::mod () - a2) + ( mint::mod () - a3)) *
164
+ irot2. val ();
165
+ a[i + offset + 3 * p] =
166
+ (a0 + ( mint::mod () - a1) + ( mint::mod () - a2na3iimag)) *
167
+ irot3. val () ;
168
+ }
169
+ if (s + 1 != ( 1 << (len - 2 )))
170
+ irot *= info. irate3 [ bsf (~( unsigned int )(s))] ;
98
171
}
99
- inow *= sum_ie[ bsf (~( unsigned int )(s))] ;
172
+ len -= 2 ;
100
173
}
101
174
}
102
175
}
103
176
104
177
template <class mint , internal::is_static_modint_t <mint>* = nullptr >
105
- std::vector<mint> convolution_naive (const std::vector<mint>& a, const std::vector<mint>& b) {
178
+ std::vector<mint> convolution_naive (const std::vector<mint>& a,
179
+ const std::vector<mint>& b) {
106
180
int n = int (a.size ()), m = int (b.size ());
107
181
std::vector<mint> ans (n + m - 1 );
108
182
if (n < m) {
@@ -150,7 +224,8 @@ std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
150
224
}
151
225
152
226
template <class mint , internal::is_static_modint_t <mint>* = nullptr >
153
- std::vector<mint> convolution (const std::vector<mint>& a, const std::vector<mint>& b) {
227
+ std::vector<mint> convolution (const std::vector<mint>& a,
228
+ const std::vector<mint>& b) {
154
229
int n = int (a.size ()), m = int (b.size ());
155
230
if (!n || !m) return {};
156
231
if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
0 commit comments