Skip to content

Commit cbfae94

Browse files
committed
Merge branch 'HotFix_multiprocess'
2 parents 07f9dbc + 0f39eb5 commit cbfae94

File tree

5 files changed

+167
-34
lines changed

5 files changed

+167
-34
lines changed

.gitlab-ci.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,24 @@ SegmentTree.cpp:
122122
- $CXX -o test/SegmentTree test/SegmentTree.cpp
123123
- test/SegmentTree
124124

125+
126+
cpp_bench:
127+
image: *dev_image
128+
stage: bench_mark_test
129+
needs:
130+
- job: ReplayBuffer.cpp
131+
artifacts: false
132+
- job: SegmentTree.cpp
133+
artifacts: false
134+
script:
135+
- $CXX -o test/segmenttree_bench test/segmenttree_bench.cpp
136+
- test/segmenttree_bench
137+
- test/segmenttree_bench
138+
- test/segmenttree_bench
139+
- test/segmenttree_bench
140+
- test/segmenttree_bench
141+
142+
125143
ReplayBuffer:
126144
<<: *py_setup
127145
script:

cpprb/ReplayBuffer.hh

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ namespace ymd {
376376
sum{PowerOf2(buffer_size),[](auto a,auto b){ return a+b; },
377377
Priority{0},
378378
sum_ptr,sum_anychanged,initialize},
379-
min{PowerOf2(buffer_size),[](Priority a,Priority b){ return std::min(a,b); },
379+
min{PowerOf2(buffer_size),[](Priority a,Priority b){ return std::min(a,b); },
380380
std::numeric_limits<Priority>::max(),
381381
min_ptr,min_anychanged,initialize},
382382
g{std::random_device{}()},
@@ -486,13 +486,6 @@ namespace ymd {
486486
void set_eps(Priority eps){
487487
this->eps = eps;
488488
}
489-
490-
void weak_update_changed(){
491-
if constexpr (MultiThread) {
492-
sum.weak_update_changed();
493-
min.weak_update_changed();
494-
}
495-
}
496489
};
497490

498491
template<typename Priority>

cpprb/SegmentTree.hh

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ namespace ymd {
6969
return tmp != buffer[i];
7070
}
7171

72-
void update_all(){
72+
void update_init(){
7373
for(std::size_t i = access_index(0) -1, end = -1; i != end; --i){
7474
update_buffer(i);
7575
}
@@ -78,6 +78,22 @@ namespace ymd {
7878
}
7979
}
8080

81+
void update_all(){
82+
constexpr const std::size_t zero = 0;
83+
const auto end = parent(access_index(buffer_size-1))+1;
84+
for(auto i = parent(access_index(0)); i != end; ++i){
85+
auto updated = update_buffer(i);
86+
auto _i = i;
87+
while((_i != zero) && updated){
88+
_i = parent(_i);
89+
updated = update_buffer(_i);
90+
}
91+
}
92+
if constexpr (MultiThread){
93+
any_changed->store(false,std::memory_order_release);
94+
}
95+
}
96+
8197
public:
8298
SegmentTree(std::size_t n,F f, T v = T{0},
8399
T* buffer_ptr = nullptr,
@@ -105,7 +121,7 @@ namespace ymd {
105121
if(initialize){
106122
std::fill_n(buffer+access_index(0),n,v);
107123

108-
update_all();
124+
update_init();
109125
}
110126
}
111127
SegmentTree(): SegmentTree{2,[](auto a,auto b){ return a+b; }} {}
@@ -142,8 +158,6 @@ namespace ymd {
142158
constexpr const std::size_t zero = 0;
143159
if(zero == max){ max = buffer_size; }
144160

145-
std::set<std::size_t> will_update{};
146-
147161
if constexpr (MultiThread){
148162
if(N){ any_changed->store(true,std::memory_order_release); }
149163
}
@@ -155,24 +169,17 @@ namespace ymd {
155169
if constexpr (!MultiThread){
156170
for(auto n = std::size_t(0); n < copy_N; ++n){
157171
auto _i = access_index(i+n);
158-
if(_i != 0){
159-
will_update.insert(parent(_i));
172+
auto updated = true;
173+
while((_i != zero) && updated){
174+
_i = parent(_i);
175+
updated = update_buffer(_i);
160176
}
161177
}
162178
}
163179

164180
N = (N > copy_N) ? N - copy_N: zero;
165181
i = zero;
166182
}
167-
168-
if constexpr (!MultiThread) {
169-
while(!will_update.empty()){
170-
i = *(will_update.rbegin());
171-
auto updated = update_buffer(i);
172-
will_update.erase(i);
173-
if(i && updated){ will_update.insert(parent(i)); }
174-
}
175-
}
176183
}
177184

178185
void set(std::size_t i,T v,std::size_t N,std::size_t max = std::size_t(0)){
@@ -190,7 +197,8 @@ namespace ymd {
190197
}
191198

192199
auto largest_region_index(std::function<bool(T)> condition,
193-
std::size_t n=std::size_t(0)) {
200+
std::size_t n=std::size_t(0),
201+
T init = T{0}) {
194202
// max index of reduce( [0,index) ) -> true
195203

196204
constexpr const std::size_t zero = 0;
@@ -203,21 +211,30 @@ namespace ymd {
203211
}
204212
}
205213

206-
std::size_t min = zero;
207-
auto max = (zero != n) ? n: buffer_size;
214+
if(n == zero){ n = buffer_size; }
215+
auto b = zero;
216+
217+
if(condition(buffer[b])){ return n-1; }
208218

209-
auto index = (min + max)/two;
219+
auto min = zero;
220+
auto max = buffer_size;
221+
auto cond = condition;
222+
auto red = init;
210223

211224
while(max - min > one){
212-
if( condition(_reduce(zero,index,zero,zero,buffer_size)) ){
213-
min = index;
225+
auto b_left = child_left(b);
226+
if(cond(buffer[b_left])){
227+
min = (min + max) / two;
228+
red = f(red, buffer[b_left]);
229+
cond = [=](auto v){ return condition(f(red,v)); };
230+
b = child_right(b);
214231
}else{
215-
max = index;
232+
max = (min + max) / two;
233+
b = b_left;
216234
}
217-
index = (min + max)/two;
218235
}
219236

220-
return index;
237+
return std::min(min, n-1);
221238
}
222239

223240
void clear(T v = T{0}){

test/PyReplayBuffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def test_5_done(self):
9595
class TestPrioritizedBase:
9696
def test_weights(self):
9797
self._check_ndarray(self.s['weights'],1,(self.batch_size,),"weights")
98-
for w in self.s['weights']:
99-
self.assertAlmostEqual(w,1.0)
98+
np.testing.assert_allclose(self.s["weights"],
99+
np.full_like(self.s["weights"], 1.0))
100100

101101
def test_indexes(self):
102102
self._check_ndarray(self.s['indexes'],1,(self.batch_size,),"indexes")

test/segmenttree_bench.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include <algorithm>
2+
#include <iostream>
3+
#include <iterator>
4+
#include <chrono>
5+
#include <vector>
6+
7+
#include <SegmentTree.hh>
8+
#include <ReplayBuffer.hh>
9+
10+
using PER = ymd::CppPrioritizedSampler<float>;
11+
using MPPER = ymd::CppThreadSafePrioritizedSampler<float>;
12+
13+
auto bench = [](auto&& F, auto n, auto fmt=""){
14+
auto t1 = std::chrono::high_resolution_clock::now();
15+
for(auto i=0ul; i < n; ++i){ F(); }
16+
auto t2 = std::chrono::high_resolution_clock::now();
17+
18+
std::cout << fmt
19+
<< std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count()
20+
<< std::endl;
21+
};
22+
23+
24+
int main(int argc, char** argv){
25+
constexpr const auto buffer_size = 1000000ul;
26+
constexpr const auto size = ymd::PowerOf2(buffer_size);
27+
28+
auto sum = ymd::SegmentTree<float>(size, [](auto a, auto b){ return a+b; });
29+
auto sum2 = ymd::SegmentTree<float,true>(size, [](auto a, auto b){ return a+b; });
30+
31+
bench([&,i=0, j=0]() mutable { sum.set(i++, j++); }, 10000, "sum1.set A: ");
32+
bench([&,i=100]() mutable { sum.set(100*(i++),
33+
[j=0]()mutable{ return j++; },
34+
100,
35+
buffer_size); }, 100, "sum1.set B: ");
36+
bench([&,i=20]() mutable { sum.set(1000*(i++),
37+
[j=0]()mutable{ return j++; },
38+
1000,
39+
buffer_size); }, 10, "sum1.set C: ");
40+
bench([&]() mutable { sum.reduce(0, 30000); }, 1, "sum1.red A: ");
41+
bench([&]() mutable { sum.reduce(0, 30000); }, 1, "sum1.red B: ");
42+
bench([&,i=0]() mutable {
43+
sum.largest_region_index([&](auto v){ return v <= 79.8*(i++); }, 30000);
44+
},
45+
10000, "sum1.lridx: ");
46+
47+
bench([&,i=0, j=0]() mutable { sum2.set(i++, j++); }, 10000, "sum2.set A: ");
48+
bench([&,i=100]() mutable { sum2.set(100*(i++),
49+
[j=0]()mutable{ return j++; },
50+
100,
51+
buffer_size); }, 100, "sum2.set B: ");
52+
bench([&,i=20]() mutable { sum2.set(1000*(i++),
53+
[j=0]()mutable{ return j++; },
54+
1000,
55+
buffer_size); }, 10, "sum2.set C: ");
56+
bench([&]() mutable { sum2.reduce(0, 30000); }, 1, "sum2.red A: ");
57+
bench([&]() mutable { sum2.reduce(0, 30000); }, 1, "sum2.red B: ");
58+
bench([&,i=0]() mutable {
59+
sum2.largest_region_index([&](auto v){ return v <= 79.8*(i++); }, 30000);
60+
},
61+
10000, "sum2.lridx: ");
62+
63+
std::cout << sum.get(1) << " " << sum2.get(1) << std::endl;
64+
std::cout << sum.get(10001) << " " << sum2.get(10001) << std::endl;
65+
66+
//
67+
68+
constexpr const auto alpha = 0.5, beta = 0.4;
69+
auto per = PER(buffer_size, alpha);
70+
auto mpper = MPPER(buffer_size, alpha);
71+
72+
auto p = std::vector<float>{};
73+
p.reserve(10000);
74+
std::generate_n(std::back_inserter(p), 10000,
75+
[i=0]() mutable { return 0.02*(i++ % 321); });
76+
77+
auto indexes = std::vector<size_t>{};
78+
indexes.reserve(32);
79+
auto weights = std::vector<float>{};
80+
weights.reserve(32);
81+
82+
bench([&, i=0,j=0]() mutable { per.set_priorities(i++, 0.02*(j++ % 321)); },
83+
10000, " PER.add1: ");
84+
bench([&, i=100,j=0]() mutable {
85+
per.set_priorities(100*(i++), p.data()+100*(j++), 100, buffer_size);
86+
},
87+
100, " PER.addN: ");
88+
bench([&]() mutable { per.sample(32,beta,weights,indexes,20000); },
89+
1, " PER.smpA: ");
90+
bench([&]() mutable { per.sample(32,beta,weights,indexes,20000); },
91+
1, " PER.smpB: ");
92+
93+
bench([&, i=0,j=0]() mutable { mpper.set_priorities(i++, 0.02*(j++ % 321)); },
94+
10000, "MPPER.add1: ");
95+
bench([&, i=100,j=0]() mutable {
96+
mpper.set_priorities(100*(i++), p.data()+100*(j++), 100, buffer_size);
97+
},
98+
100, "MPPER.addN: ");
99+
bench([&]() mutable { mpper.sample(32,beta,weights,indexes,20000); },
100+
1, "MPPER.smpA: ");
101+
bench([&]() mutable { mpper.sample(32,beta,weights,indexes,20000); },
102+
1, "MPPER.smpB: ");
103+
104+
return 0;
105+
}

0 commit comments

Comments
 (0)