-
Notifications
You must be signed in to change notification settings - Fork 1
Refactor algorithms to improve numerical stability #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ba13d03
dc99265
df19a27
0da7505
e0302e7
75e8eb4
1868587
5fd7515
6f01c81
6edac86
385920b
5c67578
26fcd3d
1b84f76
dacd5eb
66a899d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -144,4 +144,13 @@ end | |||||||||
Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype() | ||||||||||
Base.eltype(::SeqSampleIter) = Int | ||||||||||
Base.IteratorSize(::SeqSampleIter) = Base.HasLength() | ||||||||||
Base.length(s::SeqSampleIter) = s.n | ||||||||||
Base.length(s::SeqSampleIter) = s.n | ||||||||||
|
||||||||||
function faster_shuffle!(rng::AbstractRNG, vec::AbstractVector) | ||||||||||
for i in 2:length(vec) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Review the computation of 'endi' using the modulus operator on (i-1) cast to UInt; consider an explicit conversion (e.g., UInt(i-1)) to ensure the intended upper bound for sampling is correctly computed.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expression '(i-1) % UInt' appears to misuse the modulo operator with a type; if the intent is to convert (i-1) to a UInt, consider using 'UInt(i-1)' instead.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
endi = (i-1) % UInt | ||||||||||
j = @inline rand(rng, Random.Sampler(rng, UInt(0):endi, Val(1))) % Int + 1 | ||||||||||
vec[i], vec[j] = vec[j], vec[i] | ||||||||||
end | ||||||||||
vec | ||||||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,7 +73,7 @@ end | |
@inbounds s.value[s.seen_k] = el | ||
return s | ||
end | ||
j = rand(s.rng, 1:s.seen_k) | ||
j = @inline rand(s.rng, Random.Sampler(s.rng, 1:s.seen_k, Val(1))) | ||
if j <= n | ||
@inbounds s.value[j] = el | ||
update_order!(s, j) | ||
|
@@ -91,7 +91,7 @@ end | |
return s | ||
end | ||
if s.skip_k < s.seen_k | ||
j = rand(s.rng, 1:n) | ||
j = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) | ||
@inbounds s.value[j] = el | ||
update_order!(s, j) | ||
s = @inline recompute_skip!(s, n) | ||
|
@@ -114,13 +114,17 @@ end | |
end | ||
if s.skip_k < s.seen_k | ||
p = 1/s.seen_k | ||
z = exp((n-4)*log1p(-p)) | ||
c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) | ||
k = @inline choose(n, p, c, z) | ||
@inbounds for j in 1:k | ||
r = rand(s.rng, j:n) | ||
s.value[r], s.value[j] = s.value[j], el | ||
update_order_multi!(s, r, j) | ||
k = @inline choose(s.rng, n, p) | ||
if k == 1 | ||
r = @inline rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) | ||
s.value[r] = el | ||
update_order_single!(s, r) | ||
else | ||
@inbounds for j in 1:k | ||
r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) | ||
s.value[r], s.value[j] = s.value[j], el | ||
update_order_multi!(s, r, j) | ||
end | ||
end | ||
s = @inline recompute_skip!(s, n) | ||
end | ||
|
@@ -167,18 +171,30 @@ function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n) | |
return s | ||
end | ||
|
||
function choose(n, p, c, z) | ||
@inline function choose(rng, n, p) | ||
z = exp(n*log1p(-p)) | ||
t = rand(rng, Uniform(z, 1.0)) | ||
s = n*p | ||
q = 1-p | ||
k = z*q*q*q*(q + n*p) | ||
k > c && return 1 | ||
k += n*p*(n-1)*p*z*q*q/2 | ||
k > c && return 2 | ||
k += n*p*(n-1)*p*(n-2)*p*z*q/6 | ||
k > c && return 3 | ||
k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 | ||
k > c && return 4 | ||
b = Binomial(n, p) | ||
return quantile(b, c) | ||
x = z + z*s/q | ||
x > t && return 1 | ||
s *= (n-1)*p | ||
q *= 1-p | ||
x += (s*z/q)/2 | ||
x > t && return 2 | ||
s *= (n-2)*p | ||
q *= 1-p | ||
x += (s*z/q)/6 | ||
x > t && return 3 | ||
s *= (n-3)*p | ||
q *= 1-p | ||
x += (s*z/q)/24 | ||
x > t && return 4 | ||
s *= (n-4)*p | ||
q *= 1-p | ||
x += (s*z/q)/120 | ||
x > t && return 5 | ||
return quantile(Binomial(n, p), t) | ||
end | ||
Comment on lines
+174
to
198
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Consider adding inline comments within the choose function to explain the successive threshold comparisons and the logic behind the decision to return a specific value, which would improve maintainability by clarifying the numerical stability rationale. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
|
||
update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] For consistency with other updated sampling code, consider using '@inline rand(rng, Random.Sampler(rng, j:n, Val(1)))' instead of 'rand(rng, j:n)'.
Copilot uses AI. Check for mistakes.