diff --git a/benchmark/benchmark_comparison_non_stream_WWR.jl b/benchmark/benchmark_comparison_non_stream_WWR.jl index 885e044..85497a4 100644 --- a/benchmark/benchmark_comparison_non_stream_WWR.jl +++ b/benchmark/benchmark_comparison_non_stream_WWR.jl @@ -9,50 +9,68 @@ using CairoMakie ## sequential ## ################ +using AliasTables +using Random, StatsBase, Distributions + function weighted_reservoir_sample(rng, a, ws, n) return shuffle!(rng, weighted_reservoir_sample_seq(rng, a, ws, n)[1]) end function weighted_reservoir_sample_seq(rng, a, ws, n) - m = min(length(a), n) - view_w_f_n = @view(ws[1:m]) - w_sum = sum(view_w_f_n) - reservoir = sample(rng, @view(a[1:m]), Weights(view_w_f_n, w_sum), n) - length(a) <= n && return reservoir, w_sum - w_skip = @inline skip(rng, w_sum, n) + @assert length(a) >= n + view_w_f_n = @view(ws[1:n]) + reservoir = rand(rng, AliasTable(view_w_f_n), n) + w_sum, w_sum_chunk = sum(view_w_f_n), 0.0 + w_skip_chunk = skip(rng, w_sum, n) @inbounds for i in n+1:length(a) w_el = ws[i] - w_sum += w_el - if w_sum > w_skip - p = w_el/w_sum - q = 1-p - z = exp((n-4)*log1p(-p)) - t = rand(rng, Uniform(z*q*q*q*q,1.0)) - k = @inline choose(n, p, q, t, z) - @inbounds for j in 1:k - r = rand(rng, j:n) - reservoir[r], reservoir[j] = reservoir[j], a[i] - end - w_skip = @inline skip(rng, w_sum, n) + w_sum_chunk += w_el + if w_sum_chunk > w_skip_chunk + w_sum_new = w_sum + w_sum_chunk + p = w_el/w_sum_new + k = choose(rng, n, p) + if k == 1 + reservoir[rand(rng, 1:n)] = a[i] + else + for j in 1:k + reservoir[rand(rng, j:n)], reservoir[j] = reservoir[j], a[i] + end + end + w_skip_chunk = skip(rng, w_sum_new, n) + w_sum, w_sum_chunk = w_sum_new, 0.0 end end - return reservoir, w_sum + return reservoir, w_sum + w_sum_chunk end -function skip(rng, w_sum::AbstractFloat, n) - k = exp(-randexp(rng)/n) - return w_sum/k +function skip(rng, w_sum, n) + z = exp(-randexp(rng)/n) + return w_sum*((1-z)/z) end -function choose(n, p, q, t, z) - x = z*q*q*q*(q + n*p) +@inline function choose(rng, n, p) + z = exp(n*log1p(-p)) + t = rand(rng, Uniform(z, 1.0)) + s = n*p + q = 1-p + x = z + z*s/q x > t && return 1 - x += n*p*(n-1)*p*z*q*q/2 + s *= (n-1)*p + q *= 1-p + x += (s*z/q)/2 x > t && return 2 - x += n*p*(n-1)*p*(n-2)*p*z*q/6 + s *= (n-2)*p + q *= 1-p + x += (s*z/q)/6 x > t && return 3 - x += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 + s *= (n-3)*p + q *= 1-p + x += (s*z/q)/24 x > t && return 4 + s *= (n-4)*p + q *= 1-p + x += (s*z/q)/120 + x > t && return 5 return quantile(Binomial(n, p), t) end diff --git a/benchmark/benchmark_comparison_stream_WWR.jl b/benchmark/benchmark_comparison_stream_WWR.jl index 307d86e..66c69e3 100644 --- a/benchmark/benchmark_comparison_stream_WWR.jl +++ b/benchmark/benchmark_comparison_stream_WWR.jl @@ -58,7 +58,7 @@ function update_state!(s::SampleMultiAlgAExpJWR, w) end function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR) - return shuffle!(s.rng, last.(s.value.valtree)) + return StreamSampling.faster_shuffle!(s.rng, last.(s.value.valtree)) end a = Iterators.filter(x -> x != 1, 1:10^8) diff --git a/src/SamplingInterface.jl b/src/SamplingInterface.jl index 6547c9f..3584129 100644 --- a/src/SamplingInterface.jl +++ b/src/SamplingInterface.jl @@ -185,7 +185,7 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me else m = method isa AlgL || method isa AlgR || method isa AlgD ? AlgD() : AlgORDSWR() s = collect(StreamSample{iter_type}(rng, iter, n, length(iter), m)) - return ordered ? s : shuffle!(rng, s) + return ordered ? s : faster_shuffle!(rng, s) end end function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter)) @@ -206,7 +206,7 @@ function itsample(rngs::Tuple, iters::Tuple, n::Int,; iter_types = infer_eltype. vs[i], ps[i] = update_all_p!(s, iters[i]) end ps /= sum(ps) - return shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) + return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) end function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = infer_eltype.(iters)) n_it = length(iters) @@ -217,7 +217,7 @@ function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = vs[i], ps[i] = update_all_p!(s, iters[i], wfuncs[i]) end ps /= sum(ps) - return shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) + return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) end function update_all!(s, iter) @@ -236,13 +236,13 @@ function update_all!(s, iter, ordered::Bool) for x in iter s = fit!(s, x) end - return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) + return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s)) end function update_all!(s, iter, ordered, wv) for x in iter s = fit!(s, x, wv(x)) end - return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) + return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s)) end function update_all_p!(s, iter) diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 5c8c729..c8707b9 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -144,4 +144,13 @@ end Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype() Base.eltype(::SeqSampleIter) = Int Base.IteratorSize(::SeqSampleIter) = Base.HasLength() -Base.length(s::SeqSampleIter) = s.n \ No newline at end of file +Base.length(s::SeqSampleIter) = s.n + +function faster_shuffle!(rng::AbstractRNG, vec::AbstractVector) + for i in 2:length(vec) + endi = (i-1) % UInt + j = @inline rand(rng, Random.Sampler(rng, UInt(0):endi, Val(1))) % Int + 1 + vec[i], vec[j] = vec[j], vec[i] + end + vec +end diff --git a/src/SortedSamplingSingle.jl b/src/SortedSamplingSingle.jl index 26bd98e..1bb551b 100644 --- a/src/SortedSamplingSingle.jl +++ b/src/SortedSamplingSingle.jl @@ -1,6 +1,6 @@ function sorted_sample_single(rng, iter) - k = rand(rng, 1:length(iter)) + k = rand(rng, Random.Sampler(rng, 1:length(iter), Val(1))) for (i, el) in enumerate(iter) i == k && return el end diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 947e074..bed1e78 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -73,7 +73,7 @@ end @inbounds s.value[s.seen_k] = el return s end - j = rand(s.rng, 1:s.seen_k) + j = @inline rand(s.rng, Random.Sampler(s.rng, 1:s.seen_k, Val(1))) if j <= n @inbounds s.value[j] = el update_order!(s, j) @@ -91,7 +91,7 @@ end return s end if s.skip_k < s.seen_k - j = rand(s.rng, 1:n) + j = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) @inbounds s.value[j] = el update_order!(s, j) s = @inline recompute_skip!(s, n) @@ -114,13 +114,17 @@ end end if s.skip_k < s.seen_k p = 1/s.seen_k - z = exp((n-4)*log1p(-p)) - c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) - k = @inline choose(n, p, c, z) - @inbounds for j in 1:k - r = rand(s.rng, j:n) - s.value[r], s.value[j] = s.value[j], el - update_order_multi!(s, r, j) + k = @inline choose(s.rng, n, p) + if k == 1 + r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) + s.value[r] = el + update_order_single!(s, r) + else + @inbounds for j in 1:k + r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) + s.value[r], s.value[j] = s.value[j], el + update_order_multi!(s, r, j) + end end s = @inline recompute_skip!(s, n) end @@ -167,18 +171,30 @@ function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n) return s end -function choose(n, p, c, z) +@inline function choose(rng, n, p) + z = exp(n*log1p(-p)) + t = rand(rng, Uniform(z, 1.0)) + s = n*p q = 1-p - k = z*q*q*q*(q + n*p) - k > c && return 1 - k += n*p*(n-1)*p*z*q*q/2 - k > c && return 2 - k += n*p*(n-1)*p*(n-2)*p*z*q/6 - k > c && return 3 - k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 - k > c && return 4 - b = Binomial(n, p) - return quantile(b, c) + x = z + z*s/q + x > t && return 1 + s *= (n-1)*p + q *= 1-p + x += (s*z/q)/2 + x > t && return 2 + s *= (n-2)*p + q *= 1-p + x += (s*z/q)/6 + x > t && return 3 + s *= (n-3)*p + q *= 1-p + x += (s*z/q)/24 + x > t && return 4 + s *= (n-4)*p + q *= 1-p + x += (s*z/q)/120 + x > t && return 5 + return quantile(Binomial(n, p), t) end update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 9495a2d..76c7da6 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -139,13 +139,17 @@ end end if s.skip_w <= s.state p = w/s.state - z = exp((n-4)*log1p(-p)) - c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p), 1.0)) - k = @inline choose(n, p, c, z) - @inbounds for j in 1:k - r = rand(s.rng, j:n) - s.value[r], s.value[j] = s.value[j], el - update_order_multi!(s, r, j) + k = @inline choose(s.rng, n, p) + if k == 1 + r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) + s.value[r] = el + update_order_single!(s, r) + else + @inbounds for j in 1:k + r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) + s.value[r], s.value[j] = s.value[j], el + update_order_multi!(s, r, j) + end end s = @inline recompute_skip!(s, n) end