Skip to content

Commit 198f33d

Browse files
authored
Merge pull request #159 from yosupo06/patch/conv
fix convolution constraint
2 parents 47b2ec4 + 1a2785d commit 198f33d

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

atcoder/convolution.hpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
199199
std::vector<mint> convolution_fft(std::vector<mint> a, std::vector<mint> b) {
200200
int n = int(a.size()), m = int(b.size());
201201
int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
202-
assert(mint::mod() % z == 1);
203202
a.resize(z);
204203
internal::butterfly(a);
205204
b.resize(z);
@@ -220,15 +219,22 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
220219
std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
221220
int n = int(a.size()), m = int(b.size());
222221
if (!n || !m) return {};
222+
223+
int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
224+
assert((mint::mod() - 1) % z == 0);
225+
223226
if (std::min(n, m) <= 60) return convolution_naive(a, b);
224227
return internal::convolution_fft(a, b);
225228
}
226-
227229
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
228230
std::vector<mint> convolution(const std::vector<mint>& a,
229231
const std::vector<mint>& b) {
230232
int n = int(a.size()), m = int(b.size());
231233
if (!n || !m) return {};
234+
235+
int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
236+
assert((mint::mod() - 1) % z == 0);
237+
232238
if (std::min(n, m) <= 60) return convolution_naive(a, b);
233239
return internal::convolution_fft(a, b);
234240
}
@@ -241,6 +247,10 @@ std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
241247
if (!n || !m) return {};
242248

243249
using mint = static_modint<mod>;
250+
251+
int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
252+
assert((mint::mod() - 1) % z == 0);
253+
244254
std::vector<mint> a2(n), b2(m);
245255
for (int i = 0; i < n; i++) {
246256
a2[i] = mint(a[i]);
@@ -280,7 +290,7 @@ std::vector<long long> convolution_ll(const std::vector<long long>& a,
280290
static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1, "MOD1 isn't enough to support an array length of 2^24.");
281291
static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1, "MOD2 isn't enough to support an array length of 2^24.");
282292
static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1, "MOD3 isn't enough to support an array length of 2^24.");
283-
assert(a.size() + b.size() - 1 <= (1ull << MAX_AB_BIT));
293+
assert(n + m - 1 <= (1 << MAX_AB_BIT));
284294

285295
auto c1 = convolution<MOD1>(a, b);
286296
auto c2 = convolution<MOD2>(a, b);

0 commit comments

Comments
 (0)