From ba13d034e0d083290ec12fee80509e02419de557 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:22:58 +0100 Subject: [PATCH 01/16] Refactor algorithms to improve numerical stability --- .../benchmark_comparison_non_stream_WWR.jl | 72 ++++++++++++------- 1 file changed, 45 insertions(+), 27 deletions(-) diff --git a/benchmark/benchmark_comparison_non_stream_WWR.jl b/benchmark/benchmark_comparison_non_stream_WWR.jl index 885e044..85497a4 100644 --- a/benchmark/benchmark_comparison_non_stream_WWR.jl +++ b/benchmark/benchmark_comparison_non_stream_WWR.jl @@ -9,50 +9,68 @@ using CairoMakie ## sequential ## ################ +using AliasTables +using Random, StatsBase, Distributions + function weighted_reservoir_sample(rng, a, ws, n) return shuffle!(rng, weighted_reservoir_sample_seq(rng, a, ws, n)[1]) end function weighted_reservoir_sample_seq(rng, a, ws, n) - m = min(length(a), n) - view_w_f_n = @view(ws[1:m]) - w_sum = sum(view_w_f_n) - reservoir = sample(rng, @view(a[1:m]), Weights(view_w_f_n, w_sum), n) - length(a) <= n && return reservoir, w_sum - w_skip = @inline skip(rng, w_sum, n) + @assert length(a) >= n + view_w_f_n = @view(ws[1:n]) + reservoir = rand(rng, AliasTable(view_w_f_n), n) + w_sum, w_sum_chunk = sum(view_w_f_n), 0.0 + w_skip_chunk = skip(rng, w_sum, n) @inbounds for i in n+1:length(a) w_el = ws[i] - w_sum += w_el - if w_sum > w_skip - p = w_el/w_sum - q = 1-p - z = exp((n-4)*log1p(-p)) - t = rand(rng, Uniform(z*q*q*q*q,1.0)) - k = @inline choose(n, p, q, t, z) - @inbounds for j in 1:k - r = rand(rng, j:n) - reservoir[r], reservoir[j] = reservoir[j], a[i] - end - w_skip = @inline skip(rng, w_sum, n) + w_sum_chunk += w_el + if w_sum_chunk > w_skip_chunk + w_sum_new = w_sum + w_sum_chunk + p = w_el/w_sum_new + k = choose(rng, n, p) + if k == 1 + reservoir[rand(rng, 1:n)] = a[i] + else + for j in 1:k + reservoir[rand(rng, j:n)], reservoir[j] = reservoir[j], a[i] + end + end + w_skip_chunk = skip(rng, w_sum_new, n) + w_sum, w_sum_chunk = w_sum_new, 0.0 end end - return reservoir, w_sum + return reservoir, w_sum + w_sum_chunk end -function skip(rng, w_sum::AbstractFloat, n) - k = exp(-randexp(rng)/n) - return w_sum/k +function skip(rng, w_sum, n) + z = exp(-randexp(rng)/n) + return w_sum*((1-z)/z) end -function choose(n, p, q, t, z) - x = z*q*q*q*(q + n*p) +@inline function choose(rng, n, p) + z = exp(n*log1p(-p)) + t = rand(rng, Uniform(z, 1.0)) + s = n*p + q = 1-p + x = z + z*s/q x > t && return 1 - x += n*p*(n-1)*p*z*q*q/2 + s *= (n-1)*p + q *= 1-p + x += (s*z/q)/2 x > t && return 2 - x += n*p*(n-1)*p*(n-2)*p*z*q/6 + s *= (n-2)*p + q *= 1-p + x += (s*z/q)/6 x > t && return 3 - x += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 + s *= (n-3)*p + q *= 1-p + x += (s*z/q)/24 x > t && return 4 + s *= (n-4)*p + q *= 1-p + x += (s*z/q)/120 + x > t && return 5 return quantile(Binomial(n, p), t) end From dc992651872184da183a1e3e3007d83c59ef9030 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Thu, 6 Mar 2025 16:04:27 +0100 Subject: [PATCH 02/16] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 52 ++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 947e074..508c933 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -114,13 +114,17 @@ end end if s.skip_k < s.seen_k p = 1/s.seen_k - z = exp((n-4)*log1p(-p)) - c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) - k = @inline choose(n, p, c, z) - @inbounds for j in 1:k - r = rand(s.rng, j:n) - s.value[r], s.value[j] = s.value[j], el - update_order_multi!(s, r, j) + k = @inline choose(rng, n, p) + if k == 1 + r = rand(s.rng, 1:n) + s.value[r] = el + update_order_single!(s, r) + else + @inbounds for j in 1:k + r = rand(s.rng, j:n) + s.value[r], s.value[j] = s.value[j], el + update_order_multi!(s, r, j) + end end s = @inline recompute_skip!(s, n) end @@ -167,18 +171,30 @@ function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n) return s end -function choose(n, p, c, z) +@inline function choose(rng, n, p) + z = exp(n*log1p(-p)) + t = rand(rng, Uniform(z, 1.0)) + s = n*p q = 1-p - k = z*q*q*q*(q + n*p) - k > c && return 1 - k += n*p*(n-1)*p*z*q*q/2 - k > c && return 2 - k += n*p*(n-1)*p*(n-2)*p*z*q/6 - k > c && return 3 - k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 - k > c && return 4 - b = Binomial(n, p) - return quantile(b, c) + x = z + z*s/q + x > t && return 1 + s *= (n-1)*p + q *= 1-p + x += (s*z/q)/2 + x > t && return 2 + s *= (n-2)*p + q *= 1-p + x += (s*z/q)/6 + x > t && return 3 + s *= (n-3)*p + q *= 1-p + x += (s*z/q)/24 + x > t && return 4 + s *= (n-4)*p + q *= 1-p + x += (s*z/q)/120 + x > t && return 5 + return quantile(Binomial(n, p), t) end update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing From df19a27843de6ce766a2d206a22a6dd1dbfad397 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Thu, 6 Mar 2025 16:04:54 +0100 Subject: [PATCH 03/16] Update WeightedSamplingMulti.jl --- src/WeightedSamplingMulti.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 9495a2d..299aa9d 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -139,9 +139,7 @@ end end if s.skip_w <= s.state p = w/s.state - z = exp((n-4)*log1p(-p)) - c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p), 1.0)) - k = @inline choose(n, p, c, z) + k = @inline choose(rng, n, p) @inbounds for j in 1:k r = rand(s.rng, j:n) s.value[r], s.value[j] = s.value[j], el From 0da750559cb489f976c48593a1f1c6c385754a27 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Thu, 6 Mar 2025 17:14:41 +0100 Subject: [PATCH 04/16] Update WeightedSamplingMulti.jl --- src/WeightedSamplingMulti.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 299aa9d..8936a59 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -139,7 +139,7 @@ end end if s.skip_w <= s.state p = w/s.state - k = @inline choose(rng, n, p) + k = @inline choose(s.rng, n, p) @inbounds for j in 1:k r = rand(s.rng, j:n) s.value[r], s.value[j] = s.value[j], el From e0302e7f86c0290c07ac94547a9aa2944a02c8c8 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Thu, 6 Mar 2025 17:15:02 +0100 Subject: [PATCH 05/16] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 508c933..ed4b48f 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -114,7 +114,7 @@ end end if s.skip_k < s.seen_k p = 1/s.seen_k - k = @inline choose(rng, n, p) + k = @inline choose(s.rng, n, p) if k == 1 r = rand(s.rng, 1:n) s.value[r] = el From 75e8eb4a9f386f5da07d08e2c652abae11d2a139 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Fri, 14 Mar 2025 16:59:23 +0100 Subject: [PATCH 06/16] Update WeightedSamplingMulti.jl --- src/WeightedSamplingMulti.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 8936a59..9e34bf7 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -140,10 +140,16 @@ end if s.skip_w <= s.state p = w/s.state k = @inline choose(s.rng, n, p) - @inbounds for j in 1:k - r = rand(s.rng, j:n) - s.value[r], s.value[j] = s.value[j], el - update_order_multi!(s, r, j) + if k == 1 + r = rand(s.rng, 1:n) + s.value[r] = el + update_order_single!(s, r) + else + @inbounds for j in 1:k + r = rand(s.rng, j:n) + s.value[r], s.value[j] = s.value[j], el + update_order_multi!(s, r, j) + end end s = @inline recompute_skip!(s, n) end From 1868587b8435f351298cc3ac3a8744e11a5ce99f Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:18:28 +0100 Subject: [PATCH 07/16] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index ed4b48f..5b899bb 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -121,7 +121,7 @@ end update_order_single!(s, r) else @inbounds for j in 1:k - r = rand(s.rng, j:n) + r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) s.value[r], s.value[j] = s.value[j], el update_order_multi!(s, r, j) end From 5fd751535a39052561be5a8f4dae41caf6feb38d Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:18:59 +0100 Subject: [PATCH 08/16] Update WeightedSamplingMulti.jl --- src/WeightedSamplingMulti.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 9e34bf7..a916559 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -146,7 +146,7 @@ end update_order_single!(s, r) else @inbounds for j in 1:k - r = rand(s.rng, j:n) + r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) s.value[r], s.value[j] = s.value[j], el update_order_multi!(s, r, j) end From 6f01c8132c1f9e6fbe19cf9c155185d8444014a1 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:19:46 +0100 Subject: [PATCH 09/16] Update SamplingUtils.jl --- src/SamplingUtils.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 5c8c729..c8707b9 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -144,4 +144,13 @@ end Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype() Base.eltype(::SeqSampleIter) = Int Base.IteratorSize(::SeqSampleIter) = Base.HasLength() -Base.length(s::SeqSampleIter) = s.n \ No newline at end of file +Base.length(s::SeqSampleIter) = s.n + +function faster_shuffle!(rng::AbstractRNG, vec::AbstractVector) + for i in 2:length(vec) + endi = (i-1) % UInt + j = @inline rand(rng, Random.Sampler(rng, UInt(0):endi, Val(1))) % Int + 1 + vec[i], vec[j] = vec[j], vec[i] + end + vec +end From 6edac86ecccd8ec99720d5ced8295b1c1b9e19f0 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:20:08 +0100 Subject: [PATCH 10/16] Update SamplingInterface.jl --- src/SamplingInterface.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/SamplingInterface.jl b/src/SamplingInterface.jl index 6547c9f..3584129 100644 --- a/src/SamplingInterface.jl +++ b/src/SamplingInterface.jl @@ -185,7 +185,7 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me else m = method isa AlgL || method isa AlgR || method isa AlgD ? AlgD() : AlgORDSWR() s = collect(StreamSample{iter_type}(rng, iter, n, length(iter), m)) - return ordered ? s : shuffle!(rng, s) + return ordered ? s : faster_shuffle!(rng, s) end end function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter)) @@ -206,7 +206,7 @@ function itsample(rngs::Tuple, iters::Tuple, n::Int,; iter_types = infer_eltype. vs[i], ps[i] = update_all_p!(s, iters[i]) end ps /= sum(ps) - return shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) + return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) end function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = infer_eltype.(iters)) n_it = length(iters) @@ -217,7 +217,7 @@ function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = vs[i], ps[i] = update_all_p!(s, iters[i], wfuncs[i]) end ps /= sum(ps) - return shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) + return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) end function update_all!(s, iter) @@ -236,13 +236,13 @@ function update_all!(s, iter, ordered::Bool) for x in iter s = fit!(s, x) end - return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) + return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s)) end function update_all!(s, iter, ordered, wv) for x in iter s = fit!(s, x, wv(x)) end - return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) + return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s)) end function update_all_p!(s, iter) From 385920b6a7e44103eee6de0074235df9fd86a429 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:20:34 +0100 Subject: [PATCH 11/16] Update benchmark_comparison_stream_WWR.jl --- benchmark/benchmark_comparison_stream_WWR.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/benchmark_comparison_stream_WWR.jl b/benchmark/benchmark_comparison_stream_WWR.jl index 307d86e..66c69e3 100644 --- a/benchmark/benchmark_comparison_stream_WWR.jl +++ b/benchmark/benchmark_comparison_stream_WWR.jl @@ -58,7 +58,7 @@ function update_state!(s::SampleMultiAlgAExpJWR, w) end function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR) - return shuffle!(s.rng, last.(s.value.valtree)) + return StreamSampling.faster_shuffle!(s.rng, last.(s.value.valtree)) end a = Iterators.filter(x -> x != 1, 1:10^8) From 5c675780f194cfe48ee5d10c0427141425273cfa Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:26:16 +0100 Subject: [PATCH 12/16] Update WeightedSamplingMulti.jl --- src/WeightedSamplingMulti.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index a916559..76c7da6 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -141,7 +141,7 @@ end p = w/s.state k = @inline choose(s.rng, n, p) if k == 1 - r = rand(s.rng, 1:n) + r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) s.value[r] = el update_order_single!(s, r) else From 26fcd3d16236c8a1503a7e154022723501914f84 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:26:35 +0100 Subject: [PATCH 13/16] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 5b899bb..1ca96d9 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -116,7 +116,7 @@ end p = 1/s.seen_k k = @inline choose(s.rng, n, p) if k == 1 - r = rand(s.rng, 1:n) + r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) s.value[r] = el update_order_single!(s, r) else From 1b84f76dab311c8e9839412216e94c1efdf45f0e Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:32:18 +0100 Subject: [PATCH 14/16] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 1ca96d9..bed1e78 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -73,7 +73,7 @@ end @inbounds s.value[s.seen_k] = el return s end - j = rand(s.rng, 1:s.seen_k) + j = @inline rand(s.rng, Random.Sampler(s.rng, 1:s.seen_k, Val(1))) if j <= n @inbounds s.value[j] = el update_order!(s, j) @@ -91,7 +91,7 @@ end return s end if s.skip_k < s.seen_k - j = rand(s.rng, 1:n) + j = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) @inbounds s.value[j] = el update_order!(s, j) s = @inline recompute_skip!(s, n) From dacd5eb265b6e7dc6973b3e99afb2744af374879 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:33:16 +0100 Subject: [PATCH 15/16] Update SortedSamplingSingle.jl --- src/SortedSamplingSingle.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SortedSamplingSingle.jl b/src/SortedSamplingSingle.jl index 26bd98e..0b6674b 100644 --- a/src/SortedSamplingSingle.jl +++ b/src/SortedSamplingSingle.jl @@ -1,6 +1,6 @@ function sorted_sample_single(rng, iter) - k = rand(rng, 1:length(iter)) + k = rand(s.rng, Random.Sampler(s.rng, 1:length(iter), Val(1))) for (i, el) in enumerate(iter) i == k && return el end From 66a899d9a8c3c11c98bd22664853732c2fd79aa4 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 23 Mar 2025 01:33:25 +0100 Subject: [PATCH 16/16] Update SortedSamplingSingle.jl --- src/SortedSamplingSingle.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SortedSamplingSingle.jl b/src/SortedSamplingSingle.jl index 0b6674b..1bb551b 100644 --- a/src/SortedSamplingSingle.jl +++ b/src/SortedSamplingSingle.jl @@ -1,6 +1,6 @@ function sorted_sample_single(rng, iter) - k = rand(s.rng, Random.Sampler(s.rng, 1:length(iter), Val(1))) + k = rand(rng, Random.Sampler(rng, 1:length(iter), Val(1))) for (i, el) in enumerate(iter) i == k && return el end