Skip to content

Commit 89d5d0a

Browse files
authored
Merge pull request #120 from atcoder/fft-optimize
Fft optimize
2 parents db08263 + f841a34 commit 89d5d0a

File tree

4 files changed

+225
-68
lines changed

4 files changed

+225
-68
lines changed

Diff for: atcoder/convolution.hpp

+143-68
Original file line numberDiff line numberDiff line change
@@ -14,95 +14,169 @@ namespace atcoder {
1414

1515
namespace internal {
1616

17+
template <class mint,
18+
int g = internal::primitive_root<mint::mod()>,
19+
internal::is_static_modint_t<mint>* = nullptr>
20+
struct fft_info {
21+
static constexpr int rank2 = bsf_constexpr(mint::mod() - 1);
22+
std::array<mint, rank2 + 1> root; // root[i]^(2^i) == 1
23+
std::array<mint, rank2 + 1> iroot; // root[i] * iroot[i] == 1
24+
25+
std::array<mint, std::max(0, rank2 - 2 + 1)> rate2;
26+
std::array<mint, std::max(0, rank2 - 2 + 1)> irate2;
27+
28+
std::array<mint, std::max(0, rank2 - 3 + 1)> rate3;
29+
std::array<mint, std::max(0, rank2 - 3 + 1)> irate3;
30+
31+
fft_info() {
32+
root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
33+
iroot[rank2] = root[rank2].inv();
34+
for (int i = rank2 - 1; i >= 0; i--) {
35+
root[i] = root[i + 1] * root[i + 1];
36+
iroot[i] = iroot[i + 1] * iroot[i + 1];
37+
}
38+
39+
{
40+
mint prod = 1, iprod = 1;
41+
for (int i = 0; i <= rank2 - 2; i++) {
42+
rate2[i] = root[i + 2] * prod;
43+
irate2[i] = iroot[i + 2] * iprod;
44+
prod *= iroot[i + 2];
45+
iprod *= root[i + 2];
46+
}
47+
}
48+
{
49+
mint prod = 1, iprod = 1;
50+
for (int i = 0; i <= rank2 - 3; i++) {
51+
rate3[i] = root[i + 3] * prod;
52+
irate3[i] = iroot[i + 3] * iprod;
53+
prod *= iroot[i + 3];
54+
iprod *= root[i + 3];
55+
}
56+
}
57+
}
58+
};
59+
1760
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
1861
void butterfly(std::vector<mint>& a) {
19-
static constexpr int g = internal::primitive_root<mint::mod()>;
2062
int n = int(a.size());
2163
int h = internal::ceil_pow2(n);
2264

23-
static bool first = true;
24-
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
25-
if (first) {
26-
first = false;
27-
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
28-
int cnt2 = bsf(mint::mod() - 1);
29-
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
30-
for (int i = cnt2; i >= 2; i--) {
31-
// e^(2^i) == 1
32-
es[i - 2] = e;
33-
ies[i - 2] = ie;
34-
e *= e;
35-
ie *= ie;
36-
}
37-
mint now = 1;
38-
for (int i = 0; i <= cnt2 - 2; i++) {
39-
sum_e[i] = es[i] * now;
40-
now *= ies[i];
41-
}
42-
}
43-
for (int ph = 1; ph <= h; ph++) {
44-
int w = 1 << (ph - 1), p = 1 << (h - ph);
45-
mint now = 1;
46-
for (int s = 0; s < w; s++) {
47-
int offset = s << (h - ph + 1);
48-
for (int i = 0; i < p; i++) {
49-
auto l = a[i + offset];
50-
auto r = a[i + offset + p] * now;
51-
a[i + offset] = l + r;
52-
a[i + offset + p] = l - r;
65+
static const fft_info<mint> info;
66+
67+
int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
68+
while (len < h) {
69+
if (h - len == 1) {
70+
int p = 1 << (h - len - 1);
71+
mint rot = 1;
72+
for (int s = 0; s < (1 << len); s++) {
73+
int offset = s << (h - len);
74+
for (int i = 0; i < p; i++) {
75+
auto l = a[i + offset];
76+
auto r = a[i + offset + p] * rot;
77+
a[i + offset] = l + r;
78+
a[i + offset + p] = l - r;
79+
}
80+
if (s + 1 != (1 << len))
81+
rot *= info.rate2[bsf(~(unsigned int)(s))];
82+
}
83+
len++;
84+
} else {
85+
// 4-base
86+
int p = 1 << (h - len - 2);
87+
mint rot = 1, imag = info.root[2];
88+
for (int s = 0; s < (1 << len); s++) {
89+
mint rot2 = rot * rot;
90+
mint rot3 = rot2 * rot;
91+
int offset = s << (h - len);
92+
for (int i = 0; i < p; i++) {
93+
auto mod2 = 1ULL * mint::mod() * mint::mod();
94+
auto a0 = 1ULL * a[i + offset].val();
95+
auto a1 = 1ULL * a[i + offset + p].val() * rot.val();
96+
auto a2 = 1ULL * a[i + offset + 2 * p].val() * rot2.val();
97+
auto a3 = 1ULL * a[i + offset + 3 * p].val() * rot3.val();
98+
auto a1na3imag =
99+
1ULL * mint(a1 + mod2 - a3).val() * imag.val();
100+
auto na2 = mod2 - a2;
101+
a[i + offset] = a0 + a2 + a1 + a3;
102+
a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
103+
a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
104+
a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
105+
}
106+
if (s + 1 != (1 << len))
107+
rot *= info.rate3[bsf(~(unsigned int)(s))];
53108
}
54-
now *= sum_e[bsf(~(unsigned int)(s))];
109+
len += 2;
55110
}
56111
}
57112
}
58113

59114
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
60115
void butterfly_inv(std::vector<mint>& a) {
61-
static constexpr int g = internal::primitive_root<mint::mod()>;
62116
int n = int(a.size());
63117
int h = internal::ceil_pow2(n);
64118

65-
static bool first = true;
66-
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
67-
if (first) {
68-
first = false;
69-
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
70-
int cnt2 = bsf(mint::mod() - 1);
71-
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
72-
for (int i = cnt2; i >= 2; i--) {
73-
// e^(2^i) == 1
74-
es[i - 2] = e;
75-
ies[i - 2] = ie;
76-
e *= e;
77-
ie *= ie;
78-
}
79-
mint now = 1;
80-
for (int i = 0; i <= cnt2 - 2; i++) {
81-
sum_ie[i] = ies[i] * now;
82-
now *= es[i];
83-
}
84-
}
119+
static const fft_info<mint> info;
120+
121+
int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
122+
while (len) {
123+
if (len == 1) {
124+
int p = 1 << (h - len);
125+
mint irot = 1;
126+
for (int s = 0; s < (1 << (len - 1)); s++) {
127+
int offset = s << (h - len + 1);
128+
for (int i = 0; i < p; i++) {
129+
auto l = a[i + offset];
130+
auto r = a[i + offset + p];
131+
a[i + offset] = l + r;
132+
a[i + offset + p] =
133+
(unsigned long long)(mint::mod() + l.val() - r.val()) *
134+
irot.val();
135+
;
136+
}
137+
if (s + 1 != (1 << (len - 1)))
138+
irot *= info.irate2[bsf(~(unsigned int)(s))];
139+
}
140+
len--;
141+
} else {
142+
// 4-base
143+
int p = 1 << (h - len);
144+
mint irot = 1, iimag = info.iroot[2];
145+
for (int s = 0; s < (1 << (len - 2)); s++) {
146+
mint irot2 = irot * irot;
147+
mint irot3 = irot2 * irot;
148+
int offset = s << (h - len + 2);
149+
for (int i = 0; i < p; i++) {
150+
auto a0 = 1ULL * a[i + offset + 0 * p].val();
151+
auto a1 = 1ULL * a[i + offset + 1 * p].val();
152+
auto a2 = 1ULL * a[i + offset + 2 * p].val();
153+
auto a3 = 1ULL * a[i + offset + 3 * p].val();
154+
155+
auto a2na3iimag =
156+
1ULL *
157+
mint((mint::mod() + a2 - a3) * iimag.val()).val();
85158

86-
for (int ph = h; ph >= 1; ph--) {
87-
int w = 1 << (ph - 1), p = 1 << (h - ph);
88-
mint inow = 1;
89-
for (int s = 0; s < w; s++) {
90-
int offset = s << (h - ph + 1);
91-
for (int i = 0; i < p; i++) {
92-
auto l = a[i + offset];
93-
auto r = a[i + offset + p];
94-
a[i + offset] = l + r;
95-
a[i + offset + p] =
96-
(unsigned long long)(mint::mod() + l.val() - r.val()) *
97-
inow.val();
159+
a[i + offset] = a0 + a1 + a2 + a3;
160+
a[i + offset + 1 * p] =
161+
(a0 + (mint::mod() - a1) + a2na3iimag) * irot.val();
162+
a[i + offset + 2 * p] =
163+
(a0 + a1 + (mint::mod() - a2) + (mint::mod() - a3)) *
164+
irot2.val();
165+
a[i + offset + 3 * p] =
166+
(a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) *
167+
irot3.val();
168+
}
169+
if (s + 1 != (1 << (len - 2)))
170+
irot *= info.irate3[bsf(~(unsigned int)(s))];
98171
}
99-
inow *= sum_ie[bsf(~(unsigned int)(s))];
172+
len -= 2;
100173
}
101174
}
102175
}
103176

104177
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
105-
std::vector<mint> convolution_naive(const std::vector<mint>& a, const std::vector<mint>& b) {
178+
std::vector<mint> convolution_naive(const std::vector<mint>& a,
179+
const std::vector<mint>& b) {
106180
int n = int(a.size()), m = int(b.size());
107181
std::vector<mint> ans(n + m - 1);
108182
if (n < m) {
@@ -150,7 +224,8 @@ std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
150224
}
151225

152226
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
153-
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
227+
std::vector<mint> convolution(const std::vector<mint>& a,
228+
const std::vector<mint>& b) {
154229
int n = int(a.size()), m = int(b.size());
155230
if (!n || !m) return {};
156231
if (std::min(n, m) <= 60) return convolution_naive(a, b);

Diff for: atcoder/internal_bit.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ int ceil_pow2(int n) {
1717
return x;
1818
}
1919

20+
// Compile-time-usable "bit scan forward": index of the lowest set bit of `n`.
//
// @param n `1 <= n` (n == 0 violates the precondition: no bit is ever found
//          and the shift count would run past the width of the type)
// @return minimum non-negative `x` s.t. `(n & (1U << x)) != 0`
constexpr int bsf_constexpr(unsigned int n) {
    int x = 0;
    // The mask is an unsigned literal so that probing bit 31 never shifts a
    // signed 1 into the sign bit and the `&` is purely unsigned.
    while (!(n & (1U << x))) x++;
    return x;
}
27+
2028
// @param n `1 <= n`
2129
// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0`
2230
int bsf(unsigned int n) {

Diff for: test/unittest/convolution_test.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,47 @@ TEST(ConvolutionTest, Conv18433) {
385385

386386
ASSERT_EQ(conv_naive<MOD>(a, b), convolution<MOD>(a, b));
387387
}
388+
389+
// With modulus 2, convolving two empty sequences must yield an empty result.
TEST(ConvolutionTest, Conv2) {
    std::vector<ll> no_elements;
    ASSERT_EQ(no_elements, convolution<2>(no_elements, no_elements));
}
393+
394+
// Random inputs of sizes 128 and 129 under modulus 257, checked against the
// naive reference implementation.
TEST(ConvolutionTest, Conv257) {
    const int MOD = 257;
    std::vector<ll> a(128);
    std::vector<ll> b(129);
    // Fill a completely before b so the random draws happen in index order.
    for (auto& value : a) {
        value = randint(0, MOD - 1);
    }
    for (auto& value : b) {
        value = randint(0, MOD - 1);
    }

    ASSERT_EQ(conv_naive<MOD>(a, b), convolution<MOD>(a, b));
}
406+
407+
// Tiny inputs (sizes 1 and 2) under modulus 2147483647, checked against the
// naive reference implementation.
TEST(ConvolutionTest, Conv2147483647) {
    const int MOD = 2147483647;
    using mint = static_modint<MOD>;
    std::vector<mint> a(1);
    std::vector<mint> b(2);
    // Fill a completely before b so the random draws happen in index order.
    for (auto& value : a) {
        value = randint(0, MOD - 1);
    }
    for (auto& value : b) {
        value = randint(0, MOD - 1);
    }
    ASSERT_EQ(conv_naive(a, b), convolution(a, b));
}
419+
420+
// Two random inputs of size 1024 under modulus 2130706433, checked against
// the naive reference implementation.
TEST(ConvolutionTest, Conv2130706433) {
    const int MOD = 2130706433;
    using mint = static_modint<MOD>;
    std::vector<mint> a(1024);
    std::vector<mint> b(1024);
    // Fill a completely before b so the random draws happen in index order.
    for (auto& value : a) {
        value = randint(0, MOD - 1);
    }
    for (auto& value : b) {
        value = randint(0, MOD - 1);
    }
    ASSERT_EQ(conv_naive(a, b), convolution(a, b));
}

Diff for: test/unittest/modint_test.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,29 @@ TEST(ModintTest, Mod1) {
100100
ASSERT_EQ(0, mint(true).val());
101101
}
102102

103+
// Basic arithmetic with modulus INT32_MAX, first through the runtime-modulus
// `modint`, then through `static_modint<INT32_MAX>`.
TEST(ModintTest, ModIntMax) {
    // Runtime modulus.
    modint::set_mod(INT32_MAX);
    for (int x = 0; x < 100; ++x) {
        for (int y = 0; y < 100; ++y) {
            // Products of operands below 100 never reach the modulus.
            ASSERT_EQ((modint(x) * modint(y)).val(), x * y);
        }
    }
    ASSERT_EQ((modint(1234) + modint(5678)).val(), 1234 + 5678);
    ASSERT_EQ((modint(1234) - modint(5678)).val(), INT32_MAX - 5678 + 1234);
    ASSERT_EQ((modint(1234) * modint(5678)).val(), 1234 * 5678);

    // Compile-time modulus.
    using mint = static_modint<INT32_MAX>;
    for (int x = 0; x < 100; ++x) {
        for (int y = 0; y < 100; ++y) {
            ASSERT_EQ((mint(x) * mint(y)).val(), x * y);
        }
    }
    ASSERT_EQ((mint(1234) + mint(5678)).val(), 1234 + 5678);
    ASSERT_EQ((mint(1234) - mint(5678)).val(), INT32_MAX - 5678 + 1234);
    ASSERT_EQ((mint(1234) * mint(5678)).val(), 1234 * 5678);
    ASSERT_EQ((mint(INT32_MAX) + mint(INT32_MAX)).val(), 0);
}
125+
103126
#ifndef _MSC_VER
104127

105128
TEST(ModintTest, Int128) {
@@ -158,6 +181,13 @@ TEST(ModintTest, Inv) {
158181
int x = modint(i).inv().val();
159182
ASSERT_EQ(1, (ll(x) * i) % 1'000'000'008);
160183
}
184+
185+
modint::set_mod(INT32_MAX);
186+
for (int i = 1; i < 100000; i++) {
187+
if (gcd(i, INT32_MAX) != 1) continue;
188+
int x = modint(i).inv().val();
189+
ASSERT_EQ(1, (ll(x) * i) % INT32_MAX);
190+
}
161191
}
162192

163193
TEST(ModintTest, ConstUsage) {

0 commit comments

Comments
 (0)