|
| 1 | +#include <algorithm> |
| 2 | +#include <iostream> |
| 3 | +#include <iterator> |
| 4 | +#include <chrono> |
| 5 | +#include <vector> |
| 6 | + |
| 7 | +#include <SegmentTree.hh> |
| 8 | +#include <ReplayBuffer.hh> |
| 9 | + |
| 10 | +using PER = ymd::CppPrioritizedSampler<float>; |
| 11 | +using MPPER = ymd::CppThreadSafePrioritizedSampler<float>; |
| 12 | + |
| 13 | +auto bench = [](auto&& F, auto n, auto fmt=""){ |
| 14 | + auto t1 = std::chrono::high_resolution_clock::now(); |
| 15 | + for(auto i=0ul; i < n; ++i){ F(); } |
| 16 | + auto t2 = std::chrono::high_resolution_clock::now(); |
| 17 | + |
| 18 | + std::cout << fmt |
| 19 | + << std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count() |
| 20 | + << std::endl; |
| 21 | + }; |
| 22 | + |
| 23 | + |
| 24 | +int main(int argc, char** argv){ |
| 25 | + constexpr const auto buffer_size = 1000000ul; |
| 26 | + constexpr const auto size = ymd::PowerOf2(buffer_size); |
| 27 | + |
| 28 | + auto sum = ymd::SegmentTree<float>(size, [](auto a, auto b){ return a+b; }); |
| 29 | + auto sum2 = ymd::SegmentTree<float,true>(size, [](auto a, auto b){ return a+b; }); |
| 30 | + |
| 31 | + bench([&,i=0, j=0]() mutable { sum.set(i++, j++); }, 10000, "sum1.set A: "); |
| 32 | + bench([&,i=100]() mutable { sum.set(100*(i++), |
| 33 | + [j=0]()mutable{ return j++; }, |
| 34 | + 100, |
| 35 | + buffer_size); }, 100, "sum1.set B: "); |
| 36 | + bench([&,i=20]() mutable { sum.set(1000*(i++), |
| 37 | + [j=0]()mutable{ return j++; }, |
| 38 | + 1000, |
| 39 | + buffer_size); }, 10, "sum1.set C: "); |
| 40 | + bench([&]() mutable { sum.reduce(0, 30000); }, 1, "sum1.red A: "); |
| 41 | + bench([&]() mutable { sum.reduce(0, 30000); }, 1, "sum1.red B: "); |
| 42 | + bench([&,i=0]() mutable { |
| 43 | + sum.largest_region_index([&](auto v){ return v <= 79.8*(i++); }, 30000); |
| 44 | + }, |
| 45 | + 10000, "sum1.lridx: "); |
| 46 | + |
| 47 | + bench([&,i=0, j=0]() mutable { sum2.set(i++, j++); }, 10000, "sum2.set A: "); |
| 48 | + bench([&,i=100]() mutable { sum2.set(100*(i++), |
| 49 | + [j=0]()mutable{ return j++; }, |
| 50 | + 100, |
| 51 | + buffer_size); }, 100, "sum2.set B: "); |
| 52 | + bench([&,i=20]() mutable { sum2.set(1000*(i++), |
| 53 | + [j=0]()mutable{ return j++; }, |
| 54 | + 1000, |
| 55 | + buffer_size); }, 10, "sum2.set C: "); |
| 56 | + bench([&]() mutable { sum2.reduce(0, 30000); }, 1, "sum2.red A: "); |
| 57 | + bench([&]() mutable { sum2.reduce(0, 30000); }, 1, "sum2.red B: "); |
| 58 | + bench([&,i=0]() mutable { |
| 59 | + sum2.largest_region_index([&](auto v){ return v <= 79.8*(i++); }, 30000); |
| 60 | + }, |
| 61 | + 10000, "sum2.lridx: "); |
| 62 | + |
| 63 | + std::cout << sum.get(1) << " " << sum2.get(1) << std::endl; |
| 64 | + std::cout << sum.get(10001) << " " << sum2.get(10001) << std::endl; |
| 65 | + |
| 66 | + // |
| 67 | + |
| 68 | + constexpr const auto alpha = 0.5, beta = 0.4; |
| 69 | + auto per = PER(buffer_size, alpha); |
| 70 | + auto mpper = MPPER(buffer_size, alpha); |
| 71 | + |
| 72 | + auto p = std::vector<float>{}; |
| 73 | + p.reserve(10000); |
| 74 | + std::generate_n(std::back_inserter(p), 10000, |
| 75 | + [i=0]() mutable { return 0.02*(i++ % 321); }); |
| 76 | + |
| 77 | + auto indexes = std::vector<size_t>{}; |
| 78 | + indexes.reserve(32); |
| 79 | + auto weights = std::vector<float>{}; |
| 80 | + weights.reserve(32); |
| 81 | + |
| 82 | + bench([&, i=0,j=0]() mutable { per.set_priorities(i++, 0.02*(j++ % 321)); }, |
| 83 | + 10000, " PER.add1: "); |
| 84 | + bench([&, i=100,j=0]() mutable { |
| 85 | + per.set_priorities(100*(i++), p.data()+100*(j++), 100, buffer_size); |
| 86 | + }, |
| 87 | + 100, " PER.addN: "); |
| 88 | + bench([&]() mutable { per.sample(32,beta,weights,indexes,20000); }, |
| 89 | + 1, " PER.smpA: "); |
| 90 | + bench([&]() mutable { per.sample(32,beta,weights,indexes,20000); }, |
| 91 | + 1, " PER.smpB: "); |
| 92 | + |
| 93 | + bench([&, i=0,j=0]() mutable { mpper.set_priorities(i++, 0.02*(j++ % 321)); }, |
| 94 | + 10000, "MPPER.add1: "); |
| 95 | + bench([&, i=100,j=0]() mutable { |
| 96 | + mpper.set_priorities(100*(i++), p.data()+100*(j++), 100, buffer_size); |
| 97 | + }, |
| 98 | + 100, "MPPER.addN: "); |
| 99 | + bench([&]() mutable { mpper.sample(32,beta,weights,indexes,20000); }, |
| 100 | + 1, "MPPER.smpA: "); |
| 101 | + bench([&]() mutable { mpper.sample(32,beta,weights,indexes,20000); }, |
| 102 | + 1, "MPPER.smpB: "); |
| 103 | + |
| 104 | + return 0; |
| 105 | +} |
0 commit comments