Track max level more lazily #119

Open · wants to merge 20 commits into base: lh/anew-dev-2
46 changes: 28 additions & 18 deletions src/DynamicDiscreteSamplers.jl
@@ -169,8 +169,14 @@ Base.setindex!(w::Weights, v, i::Int) = (_setindex!(w.m, Float64(v), i); w)

# Select level
x = @inline rand(rng, Random.Sampler(rng, Base.OneTo(m[4]), Val(1)))

i = _convert(Int, m[2])
mi = m[i]
if mi == 0
i = set_last_nonzero_level_decrease!(m, i)
mi = m[i]
end

@inbounds while i > 5
x <= mi && break
x -= mi
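
For orientation (this sketch is not part of the PR's changes): the selection loop patched above walks levels from the cached top index downward, subtracting each level's weight from a uniform draw until the draw lands inside a level; the new `if mi == 0` branch first repairs a stale top index via set_last_nonzero_level_decrease! (added below). A standalone sketch of the walk, using a plain vector and hypothetical names rather than the packed Memory layout:

using Random

# Standalone sketch, not the package's internals: level_weights[i] is the total
# weight of level i and top is the index of the last nonzero level.
function select_level(rng::AbstractRNG, level_weights::Vector{Int}, top::Int)
    total = sum(@view level_weights[1:top])   # assumes at least one nonzero level
    x = rand(rng, 1:total)                    # uniform draw over the total weight
    i = top
    while i > 1
        w = level_weights[i]
        x <= w && break                       # the draw lands in level i
        x -= w                                # otherwise skip level i and move down
        i -= 1
    end
    return i
end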
@@ -264,6 +270,19 @@ function _rand_slow_path(rng::AbstractRNG, m::Memory{UInt64}, i)
end
end

function set_last_nonzero_level_decrease!(m, m2)
level_weights_nonzero_index,level_weights_nonzero_subindex = get_level_weights_nonzero_indices(m2-5)
chunk = m[level_weights_nonzero_index]
while chunk == 0 # Find the new m[2]
m2 -= 64
level_weights_nonzero_index -= 1
chunk = m[level_weights_nonzero_index]
end
m2 += 63 - trailing_zeros(chunk) - level_weights_nonzero_subindex - 1
m[2] = _convert(UInt64, m2)
return m2
end
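
Not part of the diff, but for readers of the helper above: it finds the new last nonzero level by scanning the per-level occupancy bitmap backwards one 64-bit word at a time; within a word, subindex 0 is the most significant bit (see the 0x8000000000000000 >> subindex masks below), so the highest-numbered occupied level is the lowest set bit. A minimal sketch over a plain Vector{UInt64}, with hypothetical names and without the package's fixed Memory offsets:

# Sketch only: level k occupies bit 0x8000000000000000 >> ((k - 1) & 0x3f)
# of word ((k - 1) >> 6) + 1, and no level above start_level is marked.
function last_nonzero_level(bitmap::Vector{UInt64}, start_level::Int)
    word = ((start_level - 1) >> 6) + 1      # word that covers start_level
    chunk = bitmap[word]
    while chunk == 0                          # walk backwards over empty words
        word -= 1
        word == 0 && return 0                 # no occupied level at all
        chunk = bitmap[word]
    end
    subindex = 63 - trailing_zeros(chunk)     # lowest set bit = highest level in the word
    return (word - 1) * 64 + subindex + 1
end

bitmap = zeros(UInt64, 32)
bitmap[1] |= 0x8000000000000000 >> 2          # mark level 3 occupied
bitmap[2] |= 0x8000000000000000 >> 9          # mark level 74 occupied
last_nonzero_level(bitmap, 74)                # 74; after clearing that bit it would be 3

With this layout, clearing a level is a single AND with a complemented mask, and the backward scan skips 64 empty levels per word read.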

function _getindex(m::Memory{UInt64}, i::Int)
@boundscheck 1 <= i <= m[1] || throw(BoundsError(_FixedSizeWeights(m), i))
j = i + 10523
@@ -522,6 +541,12 @@ function set_global_shift_increase!(m::Memory, m2, m3::UInt64, m4) # Increase sh
end

function set_global_shift_decrease!(m::Memory, m3::UInt64, m4=m[4]) # Decrease shift, on insertion of elements

m2 = _convert(Int, m[2])
if m[m2] == 0
m2 = set_last_nonzero_level_decrease!(m, m2)
end

m3_old = m[3]
m[3] = m3
@assert signed(m3) < signed(m3_old)
@@ -530,7 +555,6 @@ function set_global_shift_decrease!(m::Memory, m3::UInt64, m4=m[4]) # Decrease s
# In any case, this only adjusts elements at or before m[2]
# from the first index that previously could have had a weight > 1 to min(m[2], the first index that can't have a weight > 1) (never empty), set weights to 1 or 0
# from the first index that could have a weight > 1 to m[2] (possibly empty), shift weights by delta.
m2 = signed(m[2])
i1 = -signed(m3)-117 # see above, this is the first index that could have weight > 1 (anything after this will have weight 1 or 0)
i1_old = -signed(m3_old)-117 # anything before this is already weight 1 or 0
flatten_range = max(i1_old, 5):min(m2, i1-1)
@@ -566,7 +590,7 @@ Base.@propagate_inbounds function update_weight!(m::Memory{UInt64}, i, shifted_s
end

get_alloced_indices(exponent::UInt64) = _convert(Int, 10268 + exponent >> 3), exponent << 3 & 0x38
get_level_weights_nonzero_indices(exponent::UInt64) = _convert(Int, 10235 + exponent >> 6), exponent & 0x3f
get_level_weights_nonzero_indices(exponent) = _convert(Int, 10235 + exponent >> 6), exponent & 0x3f

function _set_to_zero!(m::Memory, i::Int)
# Find the entry's pos in the edit map table
@@ -586,23 +610,9 @@ function _set_to_zero!(m::Memory, i::Int)
m4 = m[4]
m4 -= old_weight
if significand_sum == 0 # We zeroed out a group
level_weights_nonzero_index,level_weights_nonzero_subindex = get_level_weights_nonzero_indices(exponent)
chunk = m[level_weights_nonzero_index] &= ~(0x8000000000000000 >> level_weights_nonzero_subindex)
m[weight_index] = 0
if m4 == 0 # There are no groups left
m[2] = 4
else
m2 = m[2]
if weight_index == m2 # We zeroed out the first group
while chunk == 0 # Find the new m[2]
level_weights_nonzero_index -= 1
m2 -= 64
chunk = m[level_weights_nonzero_index]
end
m2 += 63-trailing_zeros(chunk) - level_weights_nonzero_subindex
m[2] = m2
end
end
level_weights_nonzero_index,level_weights_nonzero_subindex = get_level_weights_nonzero_indices(exponent)
m[level_weights_nonzero_index] &= ~(0x8000000000000000 >> level_weights_nonzero_subindex)
else # We did not zero out a group
shift = signed(exponent + m[3])
new_weight = _convert(UInt64, significand_sum << shift) + 1
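
Taken together (sketch below is not part of the diff), the changes above implement the pattern in the PR title: _set_to_zero! now only clears the occupancy bit and the packed weight, and the cached last-nonzero-level index m[2] is repaired later, either when sampling hits a zero level or in set_global_shift_decrease!. A toy model of the idea, with hypothetical types and a dense vector instead of the real layout:

# Toy model of lazy max-index tracking; not the package's data structure.
mutable struct LazyMax
    weights::Vector{Int}   # per-slot weights
    top::Int               # cached index of the last nonzero slot; may overshoot
end

clear!(s::LazyMax, i::Int) = (s.weights[i] = 0; s)   # O(1): never touches s.top

function current_top!(s::LazyMax)
    # Lazy repair: walk down only when a reader notices the cache is stale.
    while s.top > 0 && s.weights[s.top] == 0
        s.top -= 1
    end
    return s.top
end

s = LazyMax([3, 0, 7], 3)
clear!(s, 3)          # the true top is now 1, but s.top still says 3
current_top!(s)       # walks 3 -> 2 -> 1 and returns 1

Each step of the downward walk is paid for by an earlier clear!, so the repair is amortized constant time per deletion, which is roughly the trade this PR makes against the eager bookkeeping the old _set_to_zero! did.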
28 changes: 18 additions & 10 deletions test/invariants.jl
@@ -1,7 +1,8 @@
isdefined(@__MODULE__, :Memory) || const Memory = Vector # Compat for Julia < 1.11
_get_UInt128(m::Memory, i::Integer) = UInt128(m[i]) | (UInt128(m[i+1]) << 64)
_length_from_memory(allocated_memory::Integer) = Int((allocated_memory-10523)/7)
function verify_weights(m::Memory)
function verify_weights(w::DynamicDiscreteSamplers.Weights)
m = w.m
Comment from the Owner on lines +4 to +5: This change seems unrelated, and I'd rather not make it, because it will make it harder to throw verify(m) calls into the src internals, which will reduce its utility in debugging.

m3 = m[3]
for i in 5:2050
shift = signed(i - 4 + m3)
@@ -14,10 +15,16 @@ function verify_weights(m::Memory)
end
end

function verify_m2(m::Memory)
@assert m[2] == findlast(i -> i == 4 || m[i] != 0, 1:2050)
function verify_m2(w::DynamicDiscreteSamplers.Weights)
m = w.m
@assert m[2] >= findlast(i -> i == 4 || m[i] != 0, 1:2050)
if m[4] != 0
rand(w)
@assert m[2] == findlast(i -> m[i] != 0, 1:2050)
end
end
function verify_m4(m::Memory)
function verify_m4(w::DynamicDiscreteSamplers.Weights)
m = w.m
m4 = zero(UInt64)
for i in 5:2050
m4 = Base.checked_add(m4, m[i])
@@ -26,7 +33,8 @@ function verify_m4(m::Memory)
# @assert m4 == 0 || UInt64(2)^32 <= m4 # This invariant is now maintained loosely and lazily
end

function verify_edit_map_points_to_correct_target(m::Memory)
function verify_edit_map_points_to_correct_target(w::DynamicDiscreteSamplers.Weights)
m = w.m
filled_len = m[1]
len = _length_from_memory(length(m))
for i in 1:len
@@ -39,9 +47,9 @@ function verify_edit_map_points_to_correct_target(m::Memory)
end
end

function verify(m::Memory)
verify_weights(m)
verify_m2(m)
verify_m4(m)
verify_edit_map_points_to_correct_target(m)
function verify(w::DynamicDiscreteSamplers.Weights)
verify_weights(w)
verify_m2(w)
verify_m4(w)
verify_edit_map_points_to_correct_target(w)
end
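
A note on the verify_m2 change above (the example below is hypothetical, not part of this PR's tests): because m[2] is now only an upper bound between samples, verify_m2 first checks the bound, then draws once so the lazy repair runs, and only then asserts exact equality. A call site in the style of test/weights.jl might look like:

w = DynamicDiscreteSamplers.FixedSizeWeights(4)
w[1] = 0.5
w[3] = 2.0
w[3] = 0      # zeroing the top group may leave m[2] stale until the next draw
verify(w)     # tolerates the stale m[2], samples once, then checks it exactly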
14 changes: 7 additions & 7 deletions test/weights.jl
@@ -205,13 +205,13 @@ w = DynamicDiscreteSamplers.ResizableWeights(31)
w[11] = 9.923269000574892e-8
w[23] = 0.9876032886161744
w[31] = 1.1160998022859043
verify(w.m)
verify(w)

w = DynamicDiscreteSamplers.FixedSizeWeights(10)
w[1] = floatmin(Float64)
w[2] = floatmax(Float64)
w[2] = 0 # This previously threw an assertion error due to overflow when estimating sum of level weights
verify(w.m)
verify(w)

w = DynamicDiscreteSamplers.FixedSizeWeights(9)
v = zeros(9)
@@ -232,20 +232,20 @@ v[8] = w[8] = 0.72
v[5] = w[5] = 0.20
v[1] = w[1] = 0.71
v[3] = w[3] = 0.92
verify(w.m)
verify(w)
@test v == w

w = DynamicDiscreteSamplers.ResizableWeights(2)
w[1] = 0.95
w[2] = 6.41e14
verify(w.m)
verify(w)

# This test catches a bug that was not revealed by the RNG tests below
w = DynamicDiscreteSamplers.FixedSizeWeights(3);
w[1] = 1.5
w[2] = prevfloat(1.5)
w[3] = 2^25
verify(w.m)
verify(w)

# This test catches a bug that was not revealed by the RNG tests below.
# The final line is calibrated to have about a 50% fail rate on that bug
@@ -255,7 +255,7 @@ w .= repeat(ldexp.(1.0, -1022:1023), inner=2048)
w[(2046-16)*2048+1:2046*2048] .= 0
@test w.m[4] < 2.0^32*1.1 # Confirm that we created an interesting condition
f(w,n) = sum(Int64(rand(w)) for _ in 1:n)
verify(w.m)
verify(w)
@test f(w, 2^27) ≈ 4.1543685e6*2^27 rtol=1e-6 # This should fail less than 1e-40 of the time

# These tests have never revealed a bug that was not revealed by one of the above tests:
@@ -304,7 +304,7 @@ try
resize = rand(Bool) # Some behavior emerges when not resizing for a long period
for _ in 1:rand((10,100,3000))
@test v == w
verify(w.m)
verify(w)
if rand() < .01
sm = sum(v)
sm == 0 || statistical_test(w, v ./ sm)