@@ -199,7 +199,6 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
199
199
std::vector<mint> convolution_fft (std::vector<mint> a, std::vector<mint> b) {
200
200
int n = int (a.size ()), m = int (b.size ());
201
201
int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
202
- assert (mint::mod () % z == 1 );
203
202
a.resize (z);
204
203
internal::butterfly (a);
205
204
b.resize (z);
@@ -220,15 +219,22 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
220
219
std::vector<mint> convolution (std::vector<mint>&& a, std::vector<mint>&& b) {
221
220
int n = int (a.size ()), m = int (b.size ());
222
221
if (!n || !m) return {};
222
+
223
+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
224
+ assert ((mint::mod () - 1 ) % z == 0 );
225
+
223
226
if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
224
227
return internal::convolution_fft (a, b);
225
228
}
226
-
227
229
template <class mint , internal::is_static_modint_t <mint>* = nullptr >
228
230
std::vector<mint> convolution (const std::vector<mint>& a,
229
231
const std::vector<mint>& b) {
230
232
int n = int (a.size ()), m = int (b.size ());
231
233
if (!n || !m) return {};
234
+
235
+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
236
+ assert ((mint::mod () - 1 ) % z == 0 );
237
+
232
238
if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
233
239
return internal::convolution_fft (a, b);
234
240
}
@@ -241,6 +247,10 @@ std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
241
247
if (!n || !m) return {};
242
248
243
249
using mint = static_modint<mod>;
250
+
251
+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
252
+ assert ((mint::mod () - 1 ) % z == 0 );
253
+
244
254
std::vector<mint> a2 (n), b2 (m);
245
255
for (int i = 0 ; i < n; i++) {
246
256
a2[i] = mint (a[i]);
@@ -280,7 +290,7 @@ std::vector<long long> convolution_ll(const std::vector<long long>& a,
280
290
static_assert (MOD1 % (1ull << MAX_AB_BIT) == 1 , " MOD1 isn't enough to support an array length of 2^24." );
281
291
static_assert (MOD2 % (1ull << MAX_AB_BIT) == 1 , " MOD2 isn't enough to support an array length of 2^24." );
282
292
static_assert (MOD3 % (1ull << MAX_AB_BIT) == 1 , " MOD3 isn't enough to support an array length of 2^24." );
283
- assert (a. size () + b. size () - 1 <= (1ull << MAX_AB_BIT));
293
+ assert (n + m - 1 <= (1 << MAX_AB_BIT));
284
294
285
295
auto c1 = convolution<MOD1>(a, b);
286
296
auto c2 = convolution<MOD2>(a, b);
0 commit comments