Skip to content

Refactor algorithms to improve numerical stability #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 45 additions & 27 deletions benchmark/benchmark_comparison_non_stream_WWR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,68 @@ using CairoMakie
## sequential ##
################

using AliasTables
using Random, StatsBase, Distributions

function weighted_reservoir_sample(rng, a, ws, n)
return shuffle!(rng, weighted_reservoir_sample_seq(rng, a, ws, n)[1])
end

function weighted_reservoir_sample_seq(rng, a, ws, n)
m = min(length(a), n)
view_w_f_n = @view(ws[1:m])
w_sum = sum(view_w_f_n)
reservoir = sample(rng, @view(a[1:m]), Weights(view_w_f_n, w_sum), n)
length(a) <= n && return reservoir, w_sum
w_skip = @inline skip(rng, w_sum, n)
@assert length(a) >= n
view_w_f_n = @view(ws[1:n])
reservoir = rand(rng, AliasTable(view_w_f_n), n)
w_sum, w_sum_chunk = sum(view_w_f_n), 0.0
w_skip_chunk = skip(rng, w_sum, n)
@inbounds for i in n+1:length(a)
w_el = ws[i]
w_sum += w_el
if w_sum > w_skip
p = w_el/w_sum
q = 1-p
z = exp((n-4)*log1p(-p))
t = rand(rng, Uniform(z*q*q*q*q,1.0))
k = @inline choose(n, p, q, t, z)
@inbounds for j in 1:k
r = rand(rng, j:n)
reservoir[r], reservoir[j] = reservoir[j], a[i]
end
w_skip = @inline skip(rng, w_sum, n)
w_sum_chunk += w_el
if w_sum_chunk > w_skip_chunk
w_sum_new = w_sum + w_sum_chunk
p = w_el/w_sum_new
k = choose(rng, n, p)
if k == 1
reservoir[rand(rng, 1:n)] = a[i]
else
for j in 1:k
reservoir[rand(rng, j:n)], reservoir[j] = reservoir[j], a[i]
Copy link
Preview

Copilot AI May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] For consistency with other updated sampling code, consider using '@inline rand(rng, Random.Sampler(rng, j:n, Val(1)))' instead of 'rand(rng, j:n)'.

Suggested change
reservoir[rand(rng, j:n)], reservoir[j] = reservoir[j], a[i]
reservoir[@inline rand(rng, Random.Sampler(rng, j:n, Val(1)))], reservoir[j] = reservoir[j], a[i]

Copilot uses AI. Check for mistakes.

end
end
w_skip_chunk = skip(rng, w_sum_new, n)
w_sum, w_sum_chunk = w_sum_new, 0.0
end
end
return reservoir, w_sum
return reservoir, w_sum + w_sum_chunk
end

function skip(rng, w_sum::AbstractFloat, n)
k = exp(-randexp(rng)/n)
return w_sum/k
function skip(rng, w_sum, n)
z = exp(-randexp(rng)/n)
return w_sum*((1-z)/z)
end

function choose(n, p, q, t, z)
x = z*q*q*q*(q + n*p)
@inline function choose(rng, n, p)
z = exp(n*log1p(-p))
t = rand(rng, Uniform(z, 1.0))
s = n*p
q = 1-p
x = z + z*s/q
x > t && return 1
x += n*p*(n-1)*p*z*q*q/2
s *= (n-1)*p
q *= 1-p
x += (s*z/q)/2
x > t && return 2
x += n*p*(n-1)*p*(n-2)*p*z*q/6
s *= (n-2)*p
q *= 1-p
x += (s*z/q)/6
x > t && return 3
x += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24
s *= (n-3)*p
q *= 1-p
x += (s*z/q)/24
x > t && return 4
s *= (n-4)*p
q *= 1-p
x += (s*z/q)/120
x > t && return 5
return quantile(Binomial(n, p), t)
end

Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmark_comparison_stream_WWR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function update_state!(s::SampleMultiAlgAExpJWR, w)
end

function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR)
return shuffle!(s.rng, last.(s.value.valtree))
return StreamSampling.faster_shuffle!(s.rng, last.(s.value.valtree))
end

a = Iterators.filter(x -> x != 1, 1:10^8)
Expand Down
10 changes: 5 additions & 5 deletions src/SamplingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me
else
m = method isa AlgL || method isa AlgR || method isa AlgD ? AlgD() : AlgORDSWR()
s = collect(StreamSample{iter_type}(rng, iter, n, length(iter), m))
return ordered ? s : shuffle!(rng, s)
return ordered ? s : faster_shuffle!(rng, s)
end
end
function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter))
Expand All @@ -206,7 +206,7 @@ function itsample(rngs::Tuple, iters::Tuple, n::Int,; iter_types = infer_eltype.
vs[i], ps[i] = update_all_p!(s, iters[i])
end
ps /= sum(ps)
return shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
end
function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = infer_eltype.(iters))
n_it = length(iters)
Expand All @@ -217,7 +217,7 @@ function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types =
vs[i], ps[i] = update_all_p!(s, iters[i], wfuncs[i])
end
ps /= sum(ps)
return shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
end

function update_all!(s, iter)
Expand All @@ -236,13 +236,13 @@ function update_all!(s, iter, ordered::Bool)
for x in iter
s = fit!(s, x)
end
return ordered ? ordvalue(s) : shuffle!(s.rng, value(s))
return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s))
end
function update_all!(s, iter, ordered, wv)
for x in iter
s = fit!(s, x, wv(x))
end
return ordered ? ordvalue(s) : shuffle!(s.rng, value(s))
return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s))
end

function update_all_p!(s, iter)
Expand Down
11 changes: 10 additions & 1 deletion src/SamplingUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,13 @@ end
Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype()
Base.eltype(::SeqSampleIter) = Int
Base.IteratorSize(::SeqSampleIter) = Base.HasLength()
Base.length(s::SeqSampleIter) = s.n
Base.length(s::SeqSampleIter) = s.n

function faster_shuffle!(rng::AbstractRNG, vec::AbstractVector)
for i in 2:length(vec)
Copy link
Preview

Copilot AI May 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Review the computation of 'endi' using the modulus operator on (i-1) cast to UInt; consider an explicit conversion (e.g., UInt(i-1)) to ensure the intended upper bound for sampling is correctly computed.

Suggested change
endi = (i-1) % UInt
endi = UInt(i-1) % UInt

Copilot uses AI. Check for mistakes.

Copy link
Preview

Copilot AI May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression '(i-1) % UInt' appears to misuse the modulo operator with a type; if the intent is to convert (i-1) to a UInt, consider using 'UInt(i-1)' instead.

Suggested change
endi = (i-1) % UInt
endi = UInt(i-1)

Copilot uses AI. Check for mistakes.

endi = (i-1) % UInt
j = @inline rand(rng, Random.Sampler(rng, UInt(0):endi, Val(1))) % Int + 1
vec[i], vec[j] = vec[j], vec[i]
end
vec
end
2 changes: 1 addition & 1 deletion src/SortedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

function sorted_sample_single(rng, iter)
k = rand(rng, 1:length(iter))
k = rand(rng, Random.Sampler(rng, 1:length(iter), Val(1)))
for (i, el) in enumerate(iter)
i == k && return el
end
Expand Down
56 changes: 36 additions & 20 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ end
@inbounds s.value[s.seen_k] = el
return s
end
j = rand(s.rng, 1:s.seen_k)
j = @inline rand(s.rng, Random.Sampler(s.rng, 1:s.seen_k, Val(1)))
if j <= n
@inbounds s.value[j] = el
update_order!(s, j)
Expand All @@ -91,7 +91,7 @@ end
return s
end
if s.skip_k < s.seen_k
j = rand(s.rng, 1:n)
j = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1)))
@inbounds s.value[j] = el
update_order!(s, j)
s = @inline recompute_skip!(s, n)
Expand All @@ -114,13 +114,17 @@ end
end
if s.skip_k < s.seen_k
p = 1/s.seen_k
z = exp((n-4)*log1p(-p))
c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0))
k = @inline choose(n, p, c, z)
@inbounds for j in 1:k
r = rand(s.rng, j:n)
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
k = @inline choose(s.rng, n, p)
if k == 1
r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1)))
s.value[r] = el
update_order_single!(s, r)
else
@inbounds for j in 1:k
r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1)))
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
end
end
s = @inline recompute_skip!(s, n)
end
Expand Down Expand Up @@ -167,18 +171,30 @@ function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n)
return s
end

function choose(n, p, c, z)
@inline function choose(rng, n, p)
z = exp(n*log1p(-p))
t = rand(rng, Uniform(z, 1.0))
s = n*p
q = 1-p
k = z*q*q*q*(q + n*p)
k > c && return 1
k += n*p*(n-1)*p*z*q*q/2
k > c && return 2
k += n*p*(n-1)*p*(n-2)*p*z*q/6
k > c && return 3
k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24
k > c && return 4
b = Binomial(n, p)
return quantile(b, c)
x = z + z*s/q
x > t && return 1
s *= (n-1)*p
q *= 1-p
x += (s*z/q)/2
x > t && return 2
s *= (n-2)*p
q *= 1-p
x += (s*z/q)/6
x > t && return 3
s *= (n-3)*p
q *= 1-p
x += (s*z/q)/24
x > t && return 4
s *= (n-4)*p
q *= 1-p
x += (s*z/q)/120
x > t && return 5
return quantile(Binomial(n, p), t)
end
Comment on lines +174 to 198
Copy link
Preview

Copilot AI May 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider adding inline comments within the choose function to explain the successive threshold comparisons and the logic behind the decision to return a specific value, which would improve maintainability by clarifying the numerical stability rationale.

Copilot uses AI. Check for mistakes.


update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing
Expand Down
18 changes: 11 additions & 7 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,17 @@ end
end
if s.skip_w <= s.state
p = w/s.state
z = exp((n-4)*log1p(-p))
c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p), 1.0))
k = @inline choose(n, p, c, z)
@inbounds for j in 1:k
r = rand(s.rng, j:n)
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
k = @inline choose(s.rng, n, p)
if k == 1
r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1)))
s.value[r] = el
update_order_single!(s, r)
else
@inbounds for j in 1:k
r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1)))
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
end
end
s = @inline recompute_skip!(s, n)
end
Expand Down
Loading