Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor algorithms to improve numerical stability #113

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
72 changes: 45 additions & 27 deletions benchmark/benchmark_comparison_non_stream_WWR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,68 @@ using CairoMakie
## sequential ##
################

using AliasTables
using Random, StatsBase, Distributions

# Run the sequential weighted reservoir routine, keep only the first component
# of what it returns (the reservoir), and shuffle that reservoir in place with
# the same `rng` before handing it back.
function weighted_reservoir_sample(rng, a, ws, n)
    reservoir = first(weighted_reservoir_sample_seq(rng, a, ws, n))
    return shuffle!(rng, reservoir)
end

function weighted_reservoir_sample_seq(rng, a, ws, n)
m = min(length(a), n)
view_w_f_n = @view(ws[1:m])
w_sum = sum(view_w_f_n)
reservoir = sample(rng, @view(a[1:m]), Weights(view_w_f_n, w_sum), n)
length(a) <= n && return reservoir, w_sum
w_skip = @inline skip(rng, w_sum, n)
@assert length(a) >= n
view_w_f_n = @view(ws[1:n])
reservoir = rand(rng, AliasTable(view_w_f_n), n)
w_sum, w_sum_chunk = sum(view_w_f_n), 0.0
w_skip_chunk = skip(rng, w_sum, n)
@inbounds for i in n+1:length(a)
w_el = ws[i]
w_sum += w_el
if w_sum > w_skip
p = w_el/w_sum
q = 1-p
z = exp((n-4)*log1p(-p))
t = rand(rng, Uniform(z*q*q*q*q,1.0))
k = @inline choose(n, p, q, t, z)
@inbounds for j in 1:k
r = rand(rng, j:n)
reservoir[r], reservoir[j] = reservoir[j], a[i]
end
w_skip = @inline skip(rng, w_sum, n)
w_sum_chunk += w_el
if w_sum_chunk > w_skip_chunk
w_sum_new = w_sum + w_sum_chunk
p = w_el/w_sum_new
k = choose(rng, n, p)
if k == 1
reservoir[rand(rng, 1:n)] = a[i]
else
for j in 1:k
reservoir[rand(rng, j:n)], reservoir[j] = reservoir[j], a[i]
end
end
w_skip_chunk = skip(rng, w_sum_new, n)
w_sum, w_sum_chunk = w_sum_new, 0.0
end
end
return reservoir, w_sum
return reservoir, w_sum + w_sum_chunk
end

function skip(rng, w_sum::AbstractFloat, n)
k = exp(-randexp(rng)/n)
return w_sum/k
function skip(rng, w_sum, n)
z = exp(-randexp(rng)/n)
return w_sum*((1-z)/z)
end

function choose(n, p, q, t, z)
x = z*q*q*q*(q + n*p)
@inline function choose(rng, n, p)
z = exp(n*log1p(-p))
t = rand(rng, Uniform(z, 1.0))
s = n*p
q = 1-p
x = z + z*s/q
x > t && return 1
x += n*p*(n-1)*p*z*q*q/2
s *= (n-1)*p
q *= 1-p
x += (s*z/q)/2
x > t && return 2
x += n*p*(n-1)*p*(n-2)*p*z*q/6
s *= (n-2)*p
q *= 1-p
x += (s*z/q)/6
x > t && return 3
x += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24
s *= (n-3)*p
q *= 1-p
x += (s*z/q)/24
x > t && return 4
s *= (n-4)*p
q *= 1-p
x += (s*z/q)/120
x > t && return 5
return quantile(Binomial(n, p), t)
end

Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmark_comparison_stream_WWR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function update_state!(s::SampleMultiAlgAExpJWR, w)
end

function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR)
return shuffle!(s.rng, last.(s.value.valtree))
return StreamSampling.faster_shuffle!(s.rng, last.(s.value.valtree))
end

a = Iterators.filter(x -> x != 1, 1:10^8)
Expand Down
10 changes: 5 additions & 5 deletions src/SamplingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me
else
m = method isa AlgL || method isa AlgR || method isa AlgD ? AlgD() : AlgORDSWR()
s = collect(StreamSample{iter_type}(rng, iter, n, length(iter), m))
return ordered ? s : shuffle!(rng, s)
return ordered ? s : faster_shuffle!(rng, s)
end
end
function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter))
Expand All @@ -206,7 +206,7 @@ function itsample(rngs::Tuple, iters::Tuple, n::Int,; iter_types = infer_eltype.
vs[i], ps[i] = update_all_p!(s, iters[i])
end
ps /= sum(ps)
return shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
end
function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = infer_eltype.(iters))
n_it = length(iters)
Expand All @@ -217,7 +217,7 @@ function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types =
vs[i], ps[i] = update_all_p!(s, iters[i], wfuncs[i])
end
ps /= sum(ps)
return shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
return faster_shuffle!(rngs[1], reduce_samples(rngs, ps, vs))
end

function update_all!(s, iter)
Expand All @@ -236,13 +236,13 @@ function update_all!(s, iter, ordered::Bool)
for x in iter
s = fit!(s, x)
end
return ordered ? ordvalue(s) : shuffle!(s.rng, value(s))
return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s))
end
function update_all!(s, iter, ordered, wv)
for x in iter
s = fit!(s, x, wv(x))
end
return ordered ? ordvalue(s) : shuffle!(s.rng, value(s))
return ordered ? ordvalue(s) : faster_shuffle!(s.rng, value(s))
end

function update_all_p!(s, iter)
Expand Down
11 changes: 10 additions & 1 deletion src/SamplingUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,13 @@ end
# Iterator traits for `SeqSampleIter`: the element type is known and is `Int`,
# and the iterator advertises a known length.
Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype()
Base.eltype(::SeqSampleIter) = Int
Base.IteratorSize(::SeqSampleIter) = Base.HasLength()
Base.length(s::SeqSampleIter) = s.n
Base.length(s::SeqSampleIter) = s.n

"""
    faster_shuffle!(rng::AbstractRNG, vec::AbstractVector)

Randomly permute `vec` in place using `rng` and return `vec`.

For every position `i` from 2 up to `length(vec)` the element at `i` is swapped
with the element at a position drawn uniformly from `1:i`. Each draw uses a
`Random.Sampler` built with `Val(1)`, i.e. a sampler intended for a single draw.

# Throws
- `ArgumentError` if `vec` does not use one-based indexing.
"""
function faster_shuffle!(rng::AbstractRNG, vec::AbstractVector)
    # The loop addresses `vec` through positions 1:length(vec); reject vectors
    # with offset axes up front instead of failing (or silently permuting the
    # wrong positions) part-way through the shuffle.
    Base.require_one_based_indexing(vec)
    for i in 2:length(vec)
        upper = (i - 1) % UInt
        # Draw a zero-based offset in 0:(i-1), then shift it to a position in 1:i.
        j = @inline rand(rng, Random.Sampler(rng, UInt(0):upper, Val(1))) % Int + 1
        vec[i], vec[j] = vec[j], vec[i]
    end
    return vec
end
2 changes: 1 addition & 1 deletion src/SortedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

function sorted_sample_single(rng, iter)
k = rand(rng, 1:length(iter))
k = rand(rng, Random.Sampler(rng, 1:length(iter), Val(1)))
for (i, el) in enumerate(iter)
i == k && return el
end
Expand Down
56 changes: 36 additions & 20 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ end
@inbounds s.value[s.seen_k] = el
return s
end
j = rand(s.rng, 1:s.seen_k)
j = @inline rand(s.rng, Random.Sampler(s.rng, 1:s.seen_k, Val(1)))
if j <= n
@inbounds s.value[j] = el
update_order!(s, j)
Expand All @@ -91,7 +91,7 @@ end
return s
end
if s.skip_k < s.seen_k
j = rand(s.rng, 1:n)
j = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1)))
@inbounds s.value[j] = el
update_order!(s, j)
s = @inline recompute_skip!(s, n)
Expand All @@ -114,13 +114,17 @@ end
end
if s.skip_k < s.seen_k
p = 1/s.seen_k
z = exp((n-4)*log1p(-p))
c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0))
k = @inline choose(n, p, c, z)
@inbounds for j in 1:k
r = rand(s.rng, j:n)
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
k = @inline choose(s.rng, n, p)
if k == 1
r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1)))
s.value[r] = el
update_order_single!(s, r)
else
@inbounds for j in 1:k
r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1)))
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
end
end
s = @inline recompute_skip!(s, n)
end
Expand Down Expand Up @@ -167,18 +171,30 @@ function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n)
return s
end

function choose(n, p, c, z)
@inline function choose(rng, n, p)
z = exp(n*log1p(-p))
t = rand(rng, Uniform(z, 1.0))
s = n*p
q = 1-p
k = z*q*q*q*(q + n*p)
k > c && return 1
k += n*p*(n-1)*p*z*q*q/2
k > c && return 2
k += n*p*(n-1)*p*(n-2)*p*z*q/6
k > c && return 3
k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24
k > c && return 4
b = Binomial(n, p)
return quantile(b, c)
x = z + z*s/q
x > t && return 1
s *= (n-1)*p
q *= 1-p
x += (s*z/q)/2
x > t && return 2
s *= (n-2)*p
q *= 1-p
x += (s*z/q)/6
x > t && return 3
s *= (n-3)*p
q *= 1-p
x += (s*z/q)/24
x > t && return 4
s *= (n-4)*p
q *= 1-p
x += (s*z/q)/120
x > t && return 5
return quantile(Binomial(n, p), t)
end

# No-op method: for `SampleMultiAlgR` and `SampleMultiAlgL` samplers nothing is
# updated and `nothing` is returned.
function update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j)
    return nothing
end
Expand Down
18 changes: 11 additions & 7 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,17 @@ end
end
if s.skip_w <= s.state
p = w/s.state
z = exp((n-4)*log1p(-p))
c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p), 1.0))
k = @inline choose(n, p, c, z)
@inbounds for j in 1:k
r = rand(s.rng, j:n)
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
k = @inline choose(s.rng, n, p)
if k == 1
r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1)))
s.value[r] = el
update_order_single!(s, r)
else
@inbounds for j in 1:k
r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1)))
s.value[r], s.value[j] = s.value[j], el
update_order_multi!(s, r, j)
end
end
s = @inline recompute_skip!(s, n)
end
Expand Down
Loading