diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index 4d395d7a..349662b7 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -38,11 +38,11 @@ jobs: - name: Run benchmarks run: | mkdir results - benchpkg --add https://github.com/LilithHafner/ChairmarksForAirspeedVelocity.jl ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.pull_request.head.sha}}" --output-dir=results/ --tune + benchpkg --add https://github.com/LilithHafner/ChairmarksForAirspeedVelocity.jl ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.pull_request.head.sha}}" --output-dir=results/ --tune - name: Create plots from benchmarks run: | mkdir -p plots - benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ + benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ - name: Upload plot as artifact uses: actions/upload-artifact@v4 with: @@ -50,7 +50,7 @@ jobs: path: plots - name: Create markdown table from benchmarks run: | - benchpkgtable ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --input-dir=results/ --ratio > table.md + benchpkgtable ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --input-dir=results/ --ratio > table.md echo '### Benchmark Results' > body.md echo '' >> body.md echo '' >> body.md diff --git a/.github/workflows/benchmark_pr_2.yml b/.github/workflows/benchmark_pr_2.yml new file mode 100644 index 00000000..bd4b2d55 --- /dev/null +++ b/.github/workflows/benchmark_pr_2.yml @@ -0,0 +1,86 @@ +name: Benchmark a pull request (2) +# Keep this in sync with benchmark_push.yml + +on: + pull_request: +concurrency: + # Skip and cancel intermediate builds: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + pull-requests: write + +jobs: + generate_plots: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + - uses: julia-actions/cache@v2 + - name: Extract Package Name from Project.toml + id: extract-package-name + run: | + PACKAGE_NAME=$(grep "^name" Project.toml | sed 's/^name = "\(.*\)"$/\1/') + echo "::set-output name=package_name::$PACKAGE_NAME" + - name: Build AirspeedVelocity + env: + JULIA_NUM_THREADS: 2 + run: | + # TODO: cache this build step and skip it if the cache hits (but still receive updates) + # Lightweight build step, as sometimes the runner runs out of memory: + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.add(;url="https://github.com/MilesCranmer/AirspeedVelocity.jl.git")' + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.build("AirspeedVelocity")' + - name: Add ~/.julia/bin to PATH + run: | + echo "$HOME/.julia/bin" >> $GITHUB_PATH + - name: Run 
benchmarks + run: | + mkdir results + benchpkg --add https://github.com/LilithHafner/ChairmarksForAirspeedVelocity.jl ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.pull_request.head.sha}}" --output-dir=results/ --tune + - name: Create plots from benchmarks + run: | + mkdir -p plots + benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ + - name: Upload plot as artifact + uses: actions/upload-artifact@v4 + with: + name: plots + path: plots + - name: Create markdown table from benchmarks + run: | + benchpkgtable ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --input-dir=results/ --ratio > table.md + echo '### Benchmark Results' > body.md + echo '' >> body.md + echo '' >> body.md + cat table.md >> body.md + echo '' >> body.md + echo '' >> body.md + echo '### Benchmark Plots' >> body.md + echo 'A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.' >> body.md + echo 'Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).' >> body.md + + - name: wait + run: sleep 45 + + - name: Find Comment + uses: peter-evans/find-comment@v3 + id: fcbenchmark + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: 'github-actions[bot]' + body-includes: Benchmark Results + + - name: Join Tables + run: | + echo '${{ steps.fcbenchmark.outputs.comment-body }}' >> old_body.md + julia .github/workflows/join_table.jl body.md old_body.md merged_body.md + + - name: Comment on PR + uses: peter-evans/create-or-update-comment@v4 + with: + comment-id: ${{ steps.fcbenchmark.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body-path: merged_body.md + edit-mode: replace diff --git a/.github/workflows/benchmark_pr_3.yml b/.github/workflows/benchmark_pr_3.yml new file mode 100644 index 00000000..7255f365 --- /dev/null +++ b/.github/workflows/benchmark_pr_3.yml @@ -0,0 +1,86 @@ +name: Benchmark a pull request (3) +# Keep this in sync with benchmark_push.yml + +on: + pull_request: +concurrency: + # Skip and cancel intermediate builds: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + pull-requests: write + +jobs: + generate_plots: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + - uses: julia-actions/cache@v2 + - name: Extract Package Name from Project.toml + id: extract-package-name + run: | + PACKAGE_NAME=$(grep "^name" Project.toml | sed 's/^name = "\(.*\)"$/\1/') + echo "::set-output name=package_name::$PACKAGE_NAME" + - name: Build AirspeedVelocity + env: + JULIA_NUM_THREADS: 2 + run: | + # TODO: cache this build step and skip it if the cache hits (but still receive updates) + # Lightweight build step, as sometimes the runner runs out of memory: + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.add(;url="https://github.com/MilesCranmer/AirspeedVelocity.jl.git")' + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.build("AirspeedVelocity")' + - name: Add ~/.julia/bin to PATH + run: | + echo "$HOME/.julia/bin" >> $GITHUB_PATH + - 
name: Run benchmarks + run: | + mkdir results + benchpkg --add https://github.com/LilithHafner/ChairmarksForAirspeedVelocity.jl ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.pull_request.head.sha}}" --output-dir=results/ --tune + - name: Create plots from benchmarks + run: | + mkdir -p plots + benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ + - name: Upload plot as artifact + uses: actions/upload-artifact@v4 + with: + name: plots + path: plots + - name: Create markdown table from benchmarks + run: | + benchpkgtable ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" --input-dir=results/ --ratio > table.md + echo '### Benchmark Results' > body.md + echo '' >> body.md + echo '' >> body.md + cat table.md >> body.md + echo '' >> body.md + echo '' >> body.md + echo '### Benchmark Plots' >> body.md + echo 'A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.' >> body.md + echo 'Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).' >> body.md + + - name: wait + run: sleep 90 + + - name: Find Comment + uses: peter-evans/find-comment@v3 + id: fcbenchmark + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: 'github-actions[bot]' + body-includes: Benchmark Results + + - name: Join Tables + run: | + echo '${{ steps.fcbenchmark.outputs.comment-body }}' >> old_body.md + julia .github/workflows/join_table.jl body.md old_body.md merged_body.md + + - name: Comment on PR + uses: peter-evans/create-or-update-comment@v4 + with: + comment-id: ${{ steps.fcbenchmark.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body-path: merged_body.md + edit-mode: replace diff --git a/.github/workflows/join_table.jl b/.github/workflows/join_table.jl new file mode 100644 index 00000000..2b8cc511 --- /dev/null +++ b/.github/workflows/join_table.jl @@ -0,0 +1,39 @@ +new = replace(String(read(Sys.ARGS[1])), ".../"=>" / ", "..."=>"") +old = replace(String(read(Sys.ARGS[2])), ".../"=>" / ", "..."=>"") + +function combine((n,o)) + @show n o # Debug print statement, just in case. 
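# Illustrative example (hypothetical table rows, not taken from an actual run): for rows
# with at least two data columns, e.g.
#     n = "| f | 1.0 ms | 2.0 ms |"   (row from this workflow's freshly generated table)
#     o = "| f | 1.0 ms | 1.5 ms |"   (row from the bot comment posted by an earlier workflow)
# the last column of the returned row becomes "1.5 ms,2.0 ms", so results from the
# parallel benchmark workflows accumulate comma-separated in the final column.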
+ if count(==('|'), n) <= 3 + # @assert n == o + return n + end + + n_cols = split(n, '|') + o_cols = split(o, '|') + # @assert length(n_cols) == length(o_cols) + # @assert isempty(first(n_cols)) + # @assert isempty(last(n_cols)) + # @assert isempty(first(o_cols)) + # @assert isempty(last(o_cols)) + # @assert n_cols[2] == o_cols[2] + + if all(isspace, n_cols[2]) || all(∈([':','-']), n_cols[2]) + # @assert n == o + return n + end + + o_data = strip(o_cols[end-1]) + n_data = strip(n_cols[end-1]) + n_cols[end-1] = if o_data == n_data * "," * n_data # If all three results are the same, only report one + n_data + else + o_data * "," * n_data + end + join(n_cols, '|') +end + +new2 = join(combine.(zip(split(new, '\n'), split(old, '\n'))), '\n') + +open(Sys.ARGS[3], "w") do io + write(io, new2) +end diff --git a/.gitignore b/.gitignore index 95731a59..42e91728 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ /Manifest.toml /docs/Manifest.toml /docs/build/ +results_*@*.json +results_*@*.json.tmp diff --git a/Project.toml b/Project.toml index f4d3f682..253ea9ac 100644 --- a/Project.toml +++ b/Project.toml @@ -4,14 +4,10 @@ authors = ["Lilith Orion Hafner , Adriano Meligrana > 1 + j + end + res +end + +function rand_update(ds, σ) + j = rand(ds) + delete!(ds, j) + push!(ds, j, exp(σ*randn())) + j +end + +function intermixed_h(n, σ) + ds = DynamicDiscreteSampler() + elements = Set{Int}() + res = 0 + for i in 1:n + if rand() < 0.5 + element = rand(1:n) + if element ∉ elements + push!(ds, element, exp(σ*randn())) + push!(elements, element) + end + elseif length(elements) > 0 + element = rand(elements) + delete!(ds, element) + delete!(elements, element) + end + if length(elements) > 0 + res += rand(ds) + end + end + res +end + for n in [100, 1000, 10000], σ in [.1, 1, 10, 100] - SUITE["constructor n=$n σ=$σ"] = @benchmarkable n,σ gaussian_weights_sequential_push(_...) 
+ # TODO: try to use min over noise, average over rng, and max over treatment in analysis + SUITE["constructor n=$n σ=$σ"] = @benchmarkable gaussian_weights_sequential_push($n, $σ) SUITE["sample n=$n σ=$σ"] = @benchmarkable gaussian_weights_sequential_push(n, σ) rand + SUITE["delete ∘ rand n=$n σ=$σ"] = @benchmarkable gaussian_weights_sequential_push(n, σ) rand_delete(_, $n) evals=1 + SUITE["update ∘ rand n=$n σ=$σ"] = @benchmarkable gaussian_weights_sequential_push(n, σ) rand_update(_, $σ) evals=n + SUITE["intermixed_h n=$n σ=$σ"] = @benchmarkable intermixed_h($n, $σ) + SUITE["summarysize n=$n σ=$σ"] = ChairmarksForAirspeedVelocity.Runnable() do + vector_to_trial([3600Base.summarysize(gaussian_weights_sequential_push(n, σ)) for _ in 1:1_000_000÷n]) + end +end + +function pathological1_setup() + ds = DynamicDiscreteSampler() + push!(ds, 1, 1e50) + ds +end +function pathological1_update(ds) + push!(ds, 2, 1e100) + delete!(ds, 2) +end +SUITE["pathological 1"] = @benchmarkable pathological1_setup pathological1_update +function pathological1′_update(ds) + push!(ds, 2, 1e100) + delete!(ds, 2) + rand(ds) end +SUITE["pathological 1′"] = @benchmarkable pathological1_setup pathological1′_update + +function pathological2_setup() + ds = DynamicDiscreteSampler() + push!(ds, 1, 1e-300) + ds +end +function pathological2_update(ds) + push!(ds, 2, 1e300) + delete!(ds, 2) +end +SUITE["pathological 2"] = @benchmarkable pathological2_setup pathological2_update +function pathological2′_update(ds) + push!(ds, 2, 1e300) + delete!(ds, 2) + rand(ds) +end +SUITE["pathological 2′"] = @benchmarkable pathological2_setup pathological2′_update + +pathological3 = DynamicDiscreteSampler() +push!(pathological3, 1, 1e300) +delete!(pathological3, 1) +push!(pathological3, 1, 1e-300) +SUITE["pathological 3"] = @benchmarkable pathological3 rand + +function pathological4_setup() + ds = DynamicDiscreteSampler() + push!(ds, 1, 1e-270) + ds +end +function pathological4_update(ds) + push!(ds, 2, 1e307) + delete!(ds, 2) +end +SUITE["pathological 4"] = @benchmarkable pathological4_setup pathological4_update +function pathological4′_update(ds) + push!(ds, 2, 1e307) + delete!(ds, 2) + rand(ds) +end +SUITE["pathological 4′"] = @benchmarkable pathological4_setup pathological4′_update + + +function pathological5a_setup() + ds = DynamicDiscreteSampler() + push!(ds, 1, 2.0^-32) + push!(ds, 2, 1.0) + ds +end +function pathological5a_update(ds) + push!(ds, 3, 2.0^18) + delete!(ds, 3) +end +SUITE["pathological 5a"] = @benchmarkable pathological5a_setup pathological5a_update +function pathological5b_setup() + ds = DynamicDiscreteSampler() + for i in 128:-1:1 + push!(ds, i, 2.0^-i) + end + ds +end +function pathological5b_update(ds) + push!(ds, 129, 2.0^30) + delete!(ds, 129) +end +SUITE["pathological 5b"] = @benchmarkable pathological5b_setup pathological5b_update +function pathological5b′_update(ds) + push!(ds, 129, 2.0^48) + delete!(ds, 129) + rand(ds) +end +SUITE["pathological 5b′"] = @benchmarkable pathological5b_setup pathological5b′_update + +include("code_size.jl") +_code_size = code_size(dirname(pathof(DynamicDiscreteSamplers))) + +SUITE["code size in lines"] = constant(3600_code_size.lines) +SUITE["code size in bytes"] = constant(3600_code_size.bytes) +SUITE["code size in syntax nodes"] = constant(3600_code_size.syntax_nodes) diff --git a/benchmark/code_size.jl b/benchmark/code_size.jl new file mode 100644 index 00000000..f1d822df --- /dev/null +++ b/benchmark/code_size.jl @@ -0,0 +1,44 @@ +using Base.JuliaSyntax + +add(a, b) = 
(lines=a.lines+b.lines, bytes=a.bytes+b.bytes, syntax_nodes=a.syntax_nodes+b.syntax_nodes) +function code_size(file_or_dir) + if isdir(file_or_dir) + reduce(add, code_size.(readdir(file_or_dir, join=true))) + elseif isfile(file_or_dir) + code_size_file(file_or_dir) + end +end +function code_size_file(file) + text = String(read(file)) + tokens = tokenize(text) + + byte_has_code = trues(ncodeunits(text)) + + last_end = 0 + for t in tokens + last_end = last(t.range) + if kind(t) ∈ JuliaSyntax.KSet"Comment Whitespace NewlineWs" + byte_has_code[t.range] .= false + end + end + + syntax_nodes = 0 + + stack = parseall(SyntaxNode, text, ignore_warnings=true).children + while !isempty(stack) + x = pop!(stack) + if kind(x) == K"doc" + d = x.children[1] + rng = (d.position-1) .+ (1:JuliaSyntax.span(d)) + byte_has_code[rng] .= false + x = x.children[2] + end + x.children !== nothing && append!(stack, x.children) + syntax_nodes += 1 + end + + newlines = vcat(0, findall(==('\n'), text), ncodeunits(text)) + lines = count(any(view(byte_has_code, newlines[i]+1:newlines[i+1])) for i in 1:length(newlines)-1) + + (lines=lines, bytes=sum(byte_has_code), syntax_nodes) +end diff --git a/src/DynamicDiscreteSamplers.jl b/src/DynamicDiscreteSamplers.jl index ecc366ac..3569a637 100644 --- a/src/DynamicDiscreteSamplers.jl +++ b/src/DynamicDiscreteSamplers.jl @@ -1,492 +1,823 @@ module DynamicDiscreteSamplers -export DynamicDiscreteSampler, SamplerIndices +export DynamicDiscreteSampler -using Distributions, Random, StaticArrays +using Random -struct SelectionSampler{N} - p::MVector{N, Float64} - o::MVector{N, Int16} +isdefined(@__MODULE__, :Memory) || const Memory = Vector # Compat for Julia < 1.11 + +const DEBUG = Base.JLOptions().check_bounds == 1 +_convert(T, x) = DEBUG ? T(x) : x%T + +""" + Weights <: AbstractVector{Float64} + +An abstract vector capable of storing normal, non-negative floating point numbers on which +`rand` samples an index according to values rather than sampling a value uniformly. +""" +abstract type Weights <: AbstractVector{Float64} end +struct FixedSizeWeights <: Weights + m::Memory{UInt64} + global _FixedSizeWeights + _FixedSizeWeights(m::Memory{UInt64}) = new(m) end -function Base.rand(rng::AbstractRNG, ss::SelectionSampler, lastfull::Int) - u = rand(rng)*ss.p[lastfull] - @inbounds for i in lastfull-1:-1:1 - ss.p[i] < u && return i+1 - end - return 1 +struct SemiResizableWeights <: Weights + m::Memory{UInt64} + SemiResizableWeights(w::FixedSizeWeights) = new(w.m) end -function set_cum_weights!(ss::SelectionSampler, ns, reorder) - p, lastfull = ns.sampled_level_weights, ns.track_info.lastfull - if reorder - ns.track_info.reset_order = 0 - if !ns.track_info.reset_distribution && issorted(@view(p[1:lastfull])) - return ss - end - @inline reorder_levels(ns, ss, p, lastfull) - ns.track_info.firstchanged = 1 +mutable struct ResizableWeights <: Weights + m::Memory{UInt64} + ResizableWeights(w::FixedSizeWeights) = new(w.m) +end + +#===== Overview ====== + +# Objective + +This package provides a discrete random sampler with the following key properties + - Exact: sampling probability exactly matches provided weights + - O(1) worst case expected runtime for sampling + (though termination only guaranteed probabilistically) + - O(1) worst case amortized update time to change the weight of any element + (an individual update may take up to O(n)) + - O(n) space complexity + - O(n) construction time + - Fast constant factor in practice. 
Typical usage has a constant factor of tens of clock + cycles and pathological usage has a constant factor of thousands of clock cycles. + + +# Brief implementation overview + +Weights are are divided into levels according to their exponents. To sample, first sample a +level and then sample an element within that level. + + +# Definition of terms + +v::Float64 aka weight + An entry in a Weights object set with `w[i] = v`, retrieved with `v = w[i]`. +exponent::UInt64 + The exponent of a weight is `reinterpret(UInt64, weight) >> 52`. + Note that this is _not_ the same as `Base.exponent(weight)` nor + `reinterpret(UInt64, weight) & Base.exponent_mask(Float64)`. +level + All the weights in a Weights object that have the same exponent. +significand::UInt64 + The significand of a weight is `reinterpret(UInt64, weight) << 11 | 0x8000000000000000`. + Ranges from 0x8000000000000000 for 1.0 to 0xfffffffffffff800 for 1.9999... + The ratio of two weights with the same exponent is the ratio of their significands. +significand_sum::UInt128 + The sum of the significands of the weights in a given level. + Is 0 for empty levels and otherwise ranges from widen(0x8000000000000000) to + widen(0xfffffffffffff800) * length(weights). +weight + Refers to the relative likely hood of an event. Weights are like probabilities except + they do not need to sum to one. In this codebase, the term "weight" is used to refer to + four things: the weight of an element relative to all the other elements in in a + Weights object; the weight of an element relative to the other elements in its level; + the weight of a level relative to the other levels as defined by level_weights; and the + weight of a level relative to the other levels as defined by significand_sums. + + +# Implementation and data structure overview + +Weights are normal, non-negative Float64s. They are divided into levels according to their +exponents. Each level has a weight which is the exact sum of the weights in that level. We +can't represent this sum exactly as a Float64 so we represent it as significand_sum::UInt128 +which is the sum of the significands of the weights in that level. To get the level's weight, +compute big(significand_sum)< +# 1 length::Int +# 2 max_level::Int # absolute pointer to the last element of level weights that is nonzero +# 3 shift::Int level weights are equal to significand_sums<<(exponent+shift), plus one if significand_sum is not zero +# 4 sum(level weights)::UInt64 +# 5..2050 level weights::[UInt64 2046] # earlier is lower. first is exponent 0x001, last is exponent 0x7fe. Subnormal are not supported (TODO). +# 2051..6142 significand_sums::[UInt128 2046] # sum of significands (the maximum significand contributes 0xfffffffffffff800) +# 6143..10234 level location info::[NamedTuple{pos::Int, length::Int} 2046] indexes into sub_weights, pos is absolute into m. +# 10235..10266 level_weights_nonzero::[Bool 2046] # map of which levels have nonzero weight (used to bump m2 efficiently when a level is zeroed out) +# 2 unused bits + +# gc info: +# 10267 next_free_space::Int (used to re-allocate) +# 16 unused bits +# 10268..10523 level allocated length::[UInt8 2046] (2^(x-1) is implied) + +# 10524..10523+len edit_map (maps index to current location in sub_weights)::[(pos<<11 + exponent)::UInt64] (zero means zero; fixed location, always at the start. Force full realloc when it OOMs. (len refers to allocated length, not m[1]) + +# 10524+2len..10523+7len sub_weights (woven with targets)::[[significand::UInt64, target::Int}]]. 
allocated_len == length_from_memory(length(m)) (len refers to allocated length, not m[1]) + +# significands are stored in sub_weights with their implicit leading 1 added +# element_from_sub_weights = 0x8000000000000000 | (reinterpret(UInt64, weight::Float64) << 11) +# And sampled with +# rand(UInt64) < element_from_sub_weights +# this means that for the lowest normal significand (52 zeros with an implicit leading one), +# achieved by 2.0, 4.0, etc the significand stored in sub_weights is 0x8000000000000000 +# and there are 2^63 pips less than that value (1/2 probability). For the +# highest normal significand (52 ones with an implicit leading 1) the significand +# stored in sub_weights is 0xfffffffffffff800 and there are 2^64-2^11 pips less than +# that value for a probability of (2^64-2^11) / 2^64 == (2^53-1) / 2^53 == prevfloat(2.0)/2.0 +@assert 0xfffffffffffff800//big(2)^64 == (UInt64(2)^53-1)//UInt64(2)^53 == big(prevfloat(2.0))/big(2.0) +@assert 0x8000000000000000 | (reinterpret(UInt64, 1.0::Float64) << 11) === 0x8000000000000000 +@assert 0x8000000000000000 | (reinterpret(UInt64, prevfloat(1.0)::Float64) << 11) === 0xfffffffffffff800 +# significand sums are literal sums of the element_from_sub_weights's (though stored +# as UInt128s because any two element_from_sub_weights's will overflow when added). + +# target can also store metadata useful for compaction. +# the range 0x0000000000000001 to 0x7fffffffffffffff (1:typemax(Int)) represents literal targets +# the range 0x8000000000000001 to 0x80000000000007fe indicates that this is an empty but non-abandoned group with exponent target-0x8000000000000000 +# the range 0xc000000000000000 to 0xffffffffffffffff indicates that the group is abandoned and has length -target. + +## Initial API: + +# setindex!, getindex, resize! (auto-zeros), scalar rand +# Trivial extensions: +# push!, delete! + +Base.rand(rng::AbstractRNG, w::Weights) = _rand(rng, w.m) +Base.getindex(w::Weights, i::Int) = _getindex(w.m, i) +Base.setindex!(w::Weights, v, i::Int) = (_setindex!(w.m, Float64(v), i); w) + +#=@inbounds=# function _rand(rng::AbstractRNG, m::Memory{UInt64}) + + @label reject + + # Select level + x = @inline rand(rng, Random.Sampler(rng, Base.OneTo(m[4]), Val(1))) + + i = _convert(Int, m[2]) + mi = m[i] + if mi == 0 + i = set_last_nonzero_level_decrease!(m, i) + mi = m[i] end - firstc = ns.track_info.firstchanged - ss.p[1] = p[1] - f = firstc + Int(firstc == 1) - @inbounds for i in f:lastfull - ss.p[i] = ss.p[i-1] + p[i] + + @inbounds while i > 5 + x <= mi && break + x -= mi + i -= 1 + mi = m[i] end - ns.track_info.firstchanged = lastfull - return ss -end -function reorder_levels(ns, ss, p, lastfull) - sortperm!(@view(ss.o[1:lastfull]), @view(p[1:lastfull]); alg=Base.Sort.InsertionSortAlg()) - @inbounds for i in 1:lastfull - if ss.o[i] == zero(Int16) - all_index = ns.level_set_map.indices[ns.sampled_level_numbers[i]+1075][1] - ns.level_set_map.indices[ns.sampled_level_numbers[i]+1075] = (all_index, i) - continue - end - value1 = ns.sampled_levels[i] - value2 = ns.sampled_level_numbers[i] - value3 = p[i] - x, y = i, Int(ss.o[i]) - while y != i - ss.o[x] = zero(Int16) - ns.sampled_levels[x] = ns.sampled_levels[y] - ns.sampled_level_numbers[x] = ns.sampled_level_numbers[y] - p[x] = p[y] - x = y - y = Int(ss.o[x]) + + if x >= mi # mi is the weight rounded down plus 1. If they are equal than we should refine further and possibly reject. + # Low-probability rejection to improve accuracy from very close to perfect. 
+ # This branch should typically be followed with probability < 2^-21. In cases where + # the probability is higher (i.e. m[4] < 2^32), _rand_slow_path will mutate m by + # modifying m[3] and recomputing approximate weights to increase m[4] above 2^32. + # This branch is still O(1) but constant factors don't matter except for in the case + # of repeated large swings in m[4] with calls to rand interspersed. + x > mi && error("This should be unreachable!") + if @noinline _rand_slow_path(rng, m, i) + @goto reject end - ns.sampled_levels[x] = value1 - ns.sampled_level_numbers[x] = value2 - p[x] = value3 - ss.o[x] = zero(Int16) - all_index = ns.level_set_map.indices[ns.sampled_level_numbers[i]+1075][1] - ns.level_set_map.indices[ns.sampled_level_numbers[i]+1075] = (all_index, i) end -end -mutable struct RejectionInfo - length::Int - maxw::Float64 - mask::UInt64 -end -struct RejectionSampler - data::Vector{Tuple{Int, Float64}} - track_info::RejectionInfo - RejectionSampler(i, v) = new([(i, v)], RejectionInfo(1, v, zero(UInt))) -end -function Random.rand(rng::AbstractRNG, rs::RejectionSampler, f::Function) - len = rs.track_info.length - mask = rs.track_info.mask - maxw = rs.track_info.maxw + # Lookup level info + j = 2i + 6133 + pos = m[j] + len = m[j+1] + + # Sample within level while true - u = rand(rng, UInt64) - i = Int(u & mask) - i >= len && continue - i += 1 - res, x = rs.data[i] - f(rng, u) * maxw < x && return (i, res) + r = rand(rng, UInt64) + k1 = (r>>leading_zeros(len-1)) + k2 = _convert(Int, k1<<1+pos) + # TODO for perf: delete the k1 < len check by maintaining all the out of bounds m[k2] equal to 0 + rand(rng, UInt64) < m[k2] * (k1 < len) && return Int(signed(m[k2+1])) end end -@inline randreuse(rng, u) = Float64(u >>> 11) * 0x1.0p-53 -@inline randnoreuse(rng, _) = rand(rng) -function Base.push!(rs::RejectionSampler, i, x) - len = rs.track_info.length += 1 - if len > length(rs.data) - resize!(rs.data, 2*length(rs.data)) - rs.track_info.mask = UInt(1) << (8*sizeof(len-1) - leading_zeros(len-1)) - 1 + +function _rand_slow_path(rng::AbstractRNG, m::Memory{UInt64}, i) + # shift::Int = exponent+m[3] + # significand_sum::UInt128 = ... + # weight::UInt64 = significand_sum<> (i - 1) + end + + # x is computed by rounding down at a certain level and then summing (and adding 1) + # m[4] will be computed by rounding up at a more precise level and then summing + # x could be 0 (treated as 1/2 when computing log2 with top_set_bit), composed of + # .9 + .9 + .9 + ... for up to about log2(length) levels + # meaning m[4] could be up to 2log2(length) times greater than predicted according to x2 + # if length is 2^64 than this could push m[4]'s top set bit up to 9 bits higher. + + # If, on the other hand, x was computed with significantly higher precision, then + # it could overflow if there were 2^64 elements in a weight. We could probably + # squeeze a few more bits out of this, but targeting 46 with a window of 46 to 53 is + # plenty good enough. 
+ + m3 = unsigned(-17 - Base.top_set_bit(x) - (m2 - 4)) + + set_global_shift_increase!(m, m2, m3, m4) # TODO for perf: special case all call sites to this function to take advantage of known shift direction and/or magnitude; also try outlining + + @assert 46 <= Base.top_set_bit(m[4]) <= 53 # Could be a higher because of the rounding up, but this should never bump top set bit by more than about 8 # TODO for perf: delete end - rs.data[len] = (i, x) - maxwn = rs.track_info.maxw - rs.track_info.maxw = ifelse(x > maxwn, x, maxwn) - rs -end -Base.isempty(rs::RejectionSampler) = length(rs) == 0 # For testing only -Base.length(rs::RejectionSampler) = rs.track_info.length # For testing only -struct LinkedListSet - data::SizedVector{34, UInt64, Vector{UInt64}} - LinkedListSet() = new(zeros(UInt64, 34)) -end -Base.in(i::Int, x::LinkedListSet) = x.data[i >> 6 + 18] & (UInt64(1) << (0x3f - (i & 0x3f))) != 0 -Base.push!(x::LinkedListSet, i::Int) = (x.data[i >> 6 + 18] |= UInt64(1) << (0x3f - (i & 0x3f)); x) -Base.delete!(x::LinkedListSet, i::Int) = (x.data[i >> 6 + 18] &= ~(UInt64(1) << (0x3f - (i & 0x3f))); x) -function Base.findnext(x::LinkedListSet, i::Int) - j = i >> 6 + 18 - @inbounds y = x.data[j] << (i & 0x3f) - y != 0 && return i + leading_zeros(y) - for j2 in j+1:34 - @inbounds c = x.data[j2] - !iszero(c) && return j2 << 6 + leading_zeros(c) - 18*64 + while true # TODO for confidence: move this to a separate, documented function and add unit tests. + x = rand(rng, UInt64) + # p_stage = significand_sum << shift & ...00000.111111...64...11110000 + shift += 64 + target = (significand_sum << shift) % UInt64 + x > target && return true + x < target && return false + shift >= 0 && return false end - return -10000 end -function Base.findprev(x::LinkedListSet, i::Int) - j = i >> 6 + 18 - @inbounds y = x.data[j] >> (0x3f - i & 0x3f) - y != 0 && return i - trailing_zeros(y) - for j2 in j-1:-1:1 - @inbounds c = x.data[j2] - !iszero(c) && return j2 << 6 - trailing_zeros(c) - 17*64 - 1 + +function set_last_nonzero_level_decrease!(m, m2) + level_weights_nonzero_index,level_weights_nonzero_subindex = get_level_weights_nonzero_indices(m2-5) + chunk = m[level_weights_nonzero_index] + while chunk == 0 # Find the new m[2] + m2 -= 64 + level_weights_nonzero_index -= 1 + chunk = m[level_weights_nonzero_index] end - return -10000 + m2 += 63 - trailing_zeros(chunk) - level_weights_nonzero_subindex - 1 + m[2] = _convert(UInt64, m2) + return m2 +end + +function _getindex(m::Memory{UInt64}, i::Int) + @boundscheck 1 <= i <= m[1] || throw(BoundsError(_FixedSizeWeights(m), i)) + j = i + 10523 + mj = m[j] + mj == 0 && return 0.0 + pos = _convert(Int, mj >> 11) + exponent = mj & 2047 + weight = m[pos] + reinterpret(Float64, (exponent<<52) | (weight - 0x8000000000000000) >> 11) end -# ------------------------------ - -#= -Each entry is assigned a level based on its power. -We have at most min(n, 2048) levels. -# Maintain a distribution over the top N levels and ignore any lower -(or maybe the top log(n) levels and treat the rest as a single level). -For each level, maintain a distribution over the elements of that level -Also, maintain a distribution over the N most significant levels. 
-To facilitate updating, but unused during sampling, also maintain, -A linked list set (supports push!, delete!, in, findnext, and findprev) of levels -A pointer to the least significant tracked level (-1075 if there are fewer than N levels) -A vector that maps elements (integers) to their level and index in the level - -To sample, -draw from the distribution over the top N levels and then -draw from the distribution over the elements of that level. - -To add a new element at a given weight, -determine the level of that weight, -create a new level if needed, -add the element to the distribution of that level, -and update the distribution over the top N levels if needed. -Log the location of the new element. - -To create a new level, -Push the level into the linked list set of levels. -If the level is below the least significant tracked level, that's all. -Otherwise, update the least significant tracked level and evict an element -from the distribution over the top N levels if necessary - -To remove an element, -Lookup the location of the element -Remove the element from the distribution of its level -If the level is now empty, remove the level from the linked list set of levels -If the level is below the least significant tracked level, that's all. -Otherwise, update the least significant tracked level and add an element to the -distribution over the top N levels if possible -=# - -struct LevelMap - presence::BitVector - indices::Vector{Tuple{Int, Int}} - function LevelMap() - presence = BitVector() - resize!(presence, 2098) - fill!(presence, false) - indices = Vector{Tuple{Int, Int}}(undef, 2098) - return new(presence, indices) +function _setindex!(m::Memory, v::Float64, i::Int) + @boundscheck 1 <= i <= m[1] || throw(BoundsError(_FixedSizeWeights(m), i)) + uv = reinterpret(UInt64, v) + if uv == 0 + _set_to_zero!(m, i) + return + end + 0x0010000000000000 <= uv <= 0x7fefffffffffffff || throw(DomainError(v, "Invalid weight")) # Excludes subnormals + + # Find the entry's pos in the edit map table + j = i + 10523 + if m[j] == 0 + _set_from_zero!(m, v, i) + else + _set_nonzero!(m, v, i) end end -struct EntryInfo - presence::BitVector - indices::Vector{Int} - EntryInfo() = new(BitVector(), Int[]) +function _set_nonzero!(m, v, i) + # TODO for performance: join these two operations + _set_to_zero!(m, i) + _set_from_zero!(m, v, i) end -mutable struct TrackInfo - lastsampled_idx::Int - lastsampled_idx_out::Int - lastsampled_idx_in::Int - least_significant_sampled_level::Int # The level number of the least significant tracked level - nvalues::Int - firstchanged::Int - lastfull::Int - reset_order::Int - reset_distribution::Bool +Base.@propagate_inbounds function get_significand_sum(m, i) + i = _convert(Int, 2i+2041) + significand_sum = UInt128(m[i]) | (UInt128(m[i+1]) << 64) +end +function update_significand_sum(m, i, delta) + j = _convert(Int, 2i+2041) + significand_sum = get_significand_sum(m, i) + delta + m[j] = significand_sum % UInt64 + m[j+1] = (significand_sum >>> 64) % UInt64 + significand_sum end -@inline sig(x::Float64) = (reinterpret(UInt64, x) & Base.significand_mask(Float64)) + Base.significand_mask(Float64) + 1 +function _set_from_zero!(m::Memory, v::Float64, i::Int) + uv = reinterpret(UInt64, v) + j = i + 10523 + @assert m[j] == 0 -@inline function flot(sg::UInt128, level::Integer) - shift = Int64(8 * sizeof(sg) - 53 - leading_zeros(sg)) - x = (sg >>= shift) % UInt64 - exp = level + shift + 1022 - reinterpret(Float64, x + (exp << 52)) -end + exponent = uv >> 52 + # update group total 
weight and total weight + significand = 0x8000000000000000 | uv << 11 + weight_index = _convert(Int, exponent + 4) + significand_sum = update_significand_sum(m, weight_index, significand) # Temporarily break the "weights are accurately computed" invariant -struct NestedSampler{N} - # Used in sampling - distribution_over_levels::SelectionSampler{N} # A distribution over 1:N - sampled_levels::MVector{N, Int16} # The top up to N levels indices - all_levels::Vector{Tuple{UInt128, RejectionSampler}} # All the levels, in insertion order, along with their total weights - - # Not used in sampling - sampled_level_weights::MVector{N, Float64} # The weights of the top up to N levels - sampled_level_numbers::MVector{N, Int16} # The level numbers of the top up to N levels - level_set::LinkedListSet # A set of which levels are non-empty (named by level number) - level_set_map::LevelMap # A mapping from level number to index in all_levels and index in sampled_levels (or 0 if not in sampled_levels) - entry_info::EntryInfo # A mapping from element to level number and index in that level (index in level is 0 if entry is not present) - track_info::TrackInfo -end + if m[4] == 0 # if we were empty, set global shift (m[3]) so that m[4] will become ~2^40. + m[3] = -24 - exponent -NestedSampler() = NestedSampler{64}() -NestedSampler{N}() where N = NestedSampler{N}( - SelectionSampler(zero(MVector{N, Float64}), MVector{N, Int16}(1:N)), - zero(MVector{N, Int16}), - Tuple{UInt128, RejectionSampler}[], - zero(MVector{N, Float64}), - zero(MVector{N, Int16}), - LinkedListSet(), - LevelMap(), - EntryInfo(), - TrackInfo(0, 0, 0, -1075, 0, 1, 0, 0, true), -) - -Base.rand(ns::NestedSampler, n::Integer) = rand(Random.default_rng(), ns, n) -function Base.rand(rng::AbstractRNG, ns::NestedSampler, n::Integer) - n < 100 && return [rand(rng, ns) for _ in 1:n] - lastfull = ns.track_info.lastfull - ws = @view(ns.sampled_level_weights[1:lastfull]) - totw = sum(ws) - maxw = maximum(ws) - maxw/totw > 0.98 && return [rand(rng, ns) for _ in 1:n] - n_each = rand(rng, Multinomial(n, ws ./ totw)) - inds = Vector{Int}(undef, n) - q = 1 - @inbounds for (level, k) in enumerate(n_each) - bucket = ns.all_levels[Int(ns.sampled_levels[level])][2] - f = length(bucket) <= 2048 ? randreuse : randnoreuse - for _ in 1:k - ti = @inline rand(rng, bucket, f) - inds[q] = ti[2] - q += 1 + shift = -24 + weight = _convert(UInt64, significand_sum << shift) + 1 + + @assert Base.top_set_bit(weight-1) == 40 # TODO for perf: delete + m[weight_index] = weight + m[4] = weight + else + shift = signed(exponent + m[3]) + if Base.top_set_bit(significand_sum)+shift > 64 + # if this would overflow, drop shift so that it renormalizes down to 48. + # this drops shift at least ~16 and makes the sum of weights at least ~2^48. # TODO: add an assert + # Base.top_set_bit(significand_sum)+shift == 48 + # Base.top_set_bit(significand_sum)+signed(exponent + m[3]) == 48 + # Base.top_set_bit(significand_sum)+signed(exponent) + signed(m[3]) == 48 + # signed(m[3]) == 48 - Base.top_set_bit(significand_sum) - signed(exponent) + m3 = 48 - Base.top_set_bit(significand_sum) - exponent + # The "weights are accurately computed" invariant is broken for weight_index, but the "sum(weights) == m[4]" invariant still holds + # set_global_shift_decrease! will do something wrong to weight_index, but preserve the "sum(weights) == m[4]" invariant. 
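# Worked example with hypothetical values: if exponent == 1024 and
# Base.top_set_bit(significand_sum) == 66, then m3 = 48 - 66 - 1024 = -1042 (stored
# wrapped as a UInt64), the level's new shift is signed(exponent + m3) == -18, and
# significand_sum << -18 (i.e. a right shift by 18) has its top set bit at 66 - 18 == 48,
# so the recomputed weight fits in a UInt64 and the sum of weights stays well below 2^64.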
+ set_global_shift_decrease!(m, m3) # TODO for perf: special case all call sites to this function to take advantage of known shift direction and/or magnitude; also try outlining + shift = signed(exponent + m3) + end + weight = _convert(UInt64, significand_sum << shift) + 1 + + old_weight = m[weight_index] + m[weight_index] = weight # The "weights are accurately computed" invariant is now restored + m4 = m[4] # The "sum(weights) == m[4]" invariant is broken + m4 -= old_weight + m4, o = Base.add_with_overflow(m4, weight) # The "sum(weights) == m4" invariant now holds, though the computation overflows + if o + # If weights overflow (>2^64) then shift down by 16 bits + m3 = m[3]-0x10 + set_global_shift_decrease!(m, m3, m4) # TODO for perf: special case all call sites to this function to take advantage of known shift direction and/or magnitude; also try outlining + if weight_index > m[2] # if the new weight was not adjusted by set_global_shift_decrease!, then adjust it manually + shift = signed(exponent+m3) + new_weight = _convert(UInt64, significand_sum << shift) + 1 + + @assert significand_sum != 0 + @assert m[weight_index] == weight + + m[weight_index] = new_weight + m[4] += new_weight-weight + end + else + m[4] = m4 end end - shuffle!(rng, inds) - return inds -end -Base.rand(ns::NestedSampler) = rand(Random.default_rng(), ns) -@inline function Base.rand(rng::AbstractRNG, ns::NestedSampler) - track_info = ns.track_info - track_info.reset_order += 1 - lastfull = track_info.lastfull - reorder = lastfull > 8 && track_info.reset_order > 300*lastfull - if track_info.reset_distribution || reorder - @inline set_cum_weights!(ns.distribution_over_levels, ns, reorder) - track_info.reset_distribution = false + m[2] = max(m[2], weight_index) # Set after insertion because update_weights! may need to update the global shift, in which case knowing the old m[2] will help it skip checking empty levels + level_weights_nonzero_index,level_weights_nonzero_subindex = get_level_weights_nonzero_indices(exponent) + m[level_weights_nonzero_index] |= 0x8000000000000000 >> level_weights_nonzero_subindex + + # lookup the group by exponent and bump length + group_length_index = _convert(Int, 4 + 3*2046 + 2exponent) + group_pos = m[group_length_index-1] + group_length = m[group_length_index]+1 + m[group_length_index] = group_length # setting this before compaction means that compaction will ensure there is enough space for this expanded group, but will also copy one index (16 bytes) of junk which could access past the end of m. The junk isn't an issue once coppied because we immediately overwrite it. The former (copying past the end of m) only happens if the group to be expanded is already kissing the end. In this case, it will end up at the end after compaction and be easily expanded afterwords. Consequently, we treat that case specially and bump group length and manually expand after compaction + allocs_index,allocs_subindex = get_alloced_indices(exponent) + allocs_chunk = m[allocs_index] + log2_allocated_size = allocs_chunk >> allocs_subindex % UInt8 - 1 + allocated_size = 1< allocated_size + next_free_space = m[10267] + # if at end already, simply extend the allocation # TODO see if removing this optimization is problematic; TODO verify the optimization is triggering + if next_free_space == (group_pos-2)+2group_length # note that this is valid even if group_length is 1 (previously zero). 
+ new_allocation_length = max(2, 2allocated_size) + new_next_free_space = next_free_space+new_allocation_length + if new_next_free_space > length(m)+1 # There isn't room; we need to compact + m[group_length_index] = group_length-1 # See comment above; we don't want to copy past the end of m + next_free_space = compact!(m, m) + group_pos = next_free_space-new_allocation_length # The group will move but remian the last group + new_next_free_space = next_free_space+new_allocation_length + @assert new_next_free_space < length(m)+1 # TODO for perf, delete this + m[group_length_index] = group_length + + # Re-lookup allocated chunk because compaction could have changed other + # chunk elements. However, the allocated size of this group could not have + # changed because it was previously maxed out. + allocs_chunk = m[allocs_index] + @assert log2_allocated_size == allocs_chunk >> allocs_subindex % UInt8 - 1 + @assert allocated_size == 1< length(m)+1 # out of space; compact. TODO for perf, consider resizing at this time slightly eagerly? + m[group_length_index] = group_length-1 # incrementing the group length before compaction is spotty because if the group was previously empty then this new group length will be ignored (compact! loops over sub_weights, not levels) + next_free_space = compact!(m, m) + m[group_length_index] = group_length + new_next_free_space = next_free_space+twice_new_allocated_size + @assert new_next_free_space < length(m)+1 # After compaction there should be room TODO for perf, delete this + + group_pos = m[group_length_index-1] # The group likely moved during compaction + + # Re-lookup allocated chunk because compaction could have changed other + # chunk elements. However, the allocated size of this group could not have + # changed because it was previously maxed out. + allocs_chunk = m[allocs_index] + @assert log2_allocated_size == allocs_chunk >> allocs_subindex % UInt8 - 1 + @assert allocated_size == 1< l_info - newl = max(2*l_info, maxi) - resize!(ns.entry_info.indices, newl) - resize!(ns.entry_info.presence, newl) - fill!(@view(ns.entry_info.presence[l_info+1:newl]), false) +function set_global_shift_increase!(m::Memory, m2, m3::UInt64, m4) # Increase shift, on deletion of elements + @assert signed(m[3]) < signed(m3) + m[3] = m3 + # Story: + # In the likely case that the weight decrease resulted in a level's weight hitting zero + # that level's weight is already updated and m4 adjusted accordingly TODO for perf don't adjust, pass the values around instead + # In any event, m4 is accurate for current weights and all weights and significand_sums's above (before) m2 are zero so we don't need to touch them + # Between m2 and i1, weights that were previously 1 may need to be increased. 
Below (past, after) i1, all weights will round up to 1 or 0 so we don't need to touch them + + #= + weight = UInt64(significand_sum< l_info - newl = max(2*l_info, i) - resize!(ns.entry_info.indices, newl) - resize!(ns.entry_info.presence, newl) - fill!(@view(ns.entry_info.presence[l_info+1:newl]), false) - elseif ns.entry_info.presence[i] - throw(ArgumentError("Element $i is already present")) +function set_global_shift_decrease!(m::Memory, m3::UInt64, m4=m[4]) # Decrease shift, on insertion of elements + + m2 = _convert(Int, m[2]) + if m[m2] == 0 + m2 = set_last_nonzero_level_decrease!(m, m2) end - return _push!(ns, i, x) -end -@inline function _push!(ns::NestedSampler{N}, i::Int, x::Float64) where N - bucketw, level = frexp(x) - level -= 1 - level_b16 = Int16(level) - ns.entry_info.presence[i] = true - if level ∉ ns.level_set - # Log the entry - ns.entry_info.indices[i] = 4096 + level + 1075 - - # Create a new level (or revive an empty level) - push!(ns.level_set, level) - existing_level_indices = ns.level_set_map.presence[level+1075] - all_levels_index = if !existing_level_indices - level_sampler = RejectionSampler(i, bucketw) - push!(ns.all_levels, (sig(x), level_sampler)) - length(ns.all_levels) - else - level_indices = ns.level_set_map.indices[level+1075] - w, level_sampler = ns.all_levels[level_indices[1]] - @assert w == 0 - @assert isempty(level_sampler) - push!(level_sampler, i, bucketw) - ns.all_levels[level_indices[1]] = (sig(x), level_sampler) - level_indices[1] - end - ns.level_set_map.presence[level+1075] = true - - # Update the sampled levels if needed - if level > ns.track_info.least_significant_sampled_level # we just created a sampled level - if ns.track_info.lastfull < N # Add the new level to the top 64 - ns.track_info.lastfull += 1 - sl_length = ns.track_info.lastfull - ns.sampled_levels[sl_length] = Int16(all_levels_index) - ns.sampled_level_weights[sl_length] = x - ns.sampled_level_numbers[sl_length] = level_b16 - ns.level_set_map.indices[level+1075] = (all_levels_index, sl_length) - if sl_length == N - ns.track_info.least_significant_sampled_level = findnext(ns.level_set, ns.track_info.least_significant_sampled_level+1) - end - else # Replace the least significant sampled level with the new level - j, k = ns.level_set_map.indices[ns.track_info.least_significant_sampled_level+1075] - ns.level_set_map.indices[ns.track_info.least_significant_sampled_level+1075] = (j, 0) - ns.sampled_levels[k] = Int16(all_levels_index) - ns.sampled_level_weights[k] = x - ns.sampled_level_numbers[k] = level_b16 - ns.level_set_map.indices[level+1075] = (all_levels_index, k) - ns.track_info.least_significant_sampled_level = findnext(ns.level_set, ns.track_info.least_significant_sampled_level+1) - firstc = ns.track_info.firstchanged - ns.track_info.firstchanged = ifelse(k < firstc, k, firstc) - end - else # created an unsampled level - ns.level_set_map.indices[level+1075] = (all_levels_index, 0) - end - else # Add to an existing level - j, k = ns.level_set_map.indices[level+1075] - w, level_sampler = ns.all_levels[j] - push!(level_sampler, i, bucketw) - ns.entry_info.indices[i] = length(level_sampler) << 12 + level + 1075 - wn = w+sig(x) - ns.all_levels[j] = (wn, level_sampler) - - if k != 0 # level is sampled - ns.sampled_level_weights[k] = flot(wn, level) - firstc = ns.track_info.firstchanged - ns.track_info.firstchanged = ifelse(k < firstc, k, firstc) - end + m3_old = m[3] + m[3] = m3 + @assert signed(m3) < signed(m3_old) + + # In the case of adding a giant element, call this first, then 
add the element. + # In any case, this only adjusts elements at or before m[2] + # from the first index that previously could have had a weight > 1 to min(m[2], the first index that can't have a weight > 1) (never empty), set weights to 1 or 0 + # from the first index that could have a weight > 1 to m[2] (possibly empty), shift weights by delta. + i1 = -signed(m3)-117 # see above, this is the first index that could have weight > 1 (anything after this will have weight 1 or 0) + i1_old = -signed(m3_old)-117 # anything before this is already weight 1 or 0 + flatten_range = max(i1_old, 5):min(m2, i1-1) + recompute_range = max(i1, 5):m2 + # From the level where one element contributes 2^64 to the level where one element contributes 1 is 64, and from there to the level where 2^64 elements contributes 1 is another 2^64. + @assert length(flatten_range) <= 128 + @assert length(recompute_range) <= 128 + + checkbounds(m, flatten_range) + @inbounds for i in flatten_range # set nonzeros to 1 + old_weight = m[i] + weight = old_weight != 0 + m[i] = weight + m4 += weight-old_weight + end + + delta = m3_old-m3 + checkbounds(m, recompute_range) + @inbounds for i in recompute_range + old_weight = m[i] + old_weight <= 1 && continue # in this case, the weight was and still is 0 or 1 + m4 += update_weight!(m, i, (old_weight-1) >> delta) end - return ns + + m[4] = m4 end -@inline function Base.delete!(ns::NestedSampler, i::Int) - ns_track_info = ns.track_info - ns_track_info.reset_distribution = true - ns_track_info.reset_order += 1 - ns_track_info.nvalues -= 1 - if i <= 0 || i > lastindex(ns.entry_info.presence) - throw(ArgumentError("Element $i is not present")) +Base.@propagate_inbounds function update_weight!(m::Memory{UInt64}, i, shifted_significand_sum) + weight = _convert(UInt64, shifted_significand_sum) + 1 + old_weight = m[i] + m[i] = weight + weight-old_weight +end + +get_alloced_indices(exponent::UInt64) = _convert(Int, 10268 + exponent >> 3), exponent << 3 & 0x38 +get_level_weights_nonzero_indices(exponent) = _convert(Int, 10235 + exponent >> 6), exponent & 0x3f + +function _set_to_zero!(m::Memory, i::Int) + # Find the entry's pos in the edit map table + j = i + 10523 + mj = m[j] + mj == 0 && return # if the entry is already zero, return + pos = _convert(Int, mj >> 11) + exponent = mj & 2047 + # set the entry to zero (no need to zero the exponent) + # m[j] = 0 is moved to after we adjust the edit_map entry for the shifted element, in case there is no shifted element + + # update group total weight and total weight + significand = m[pos] + weight_index = _convert(Int, exponent + 4) + significand_sum = update_significand_sum(m, weight_index, -UInt128(significand)) + old_weight = m[weight_index] + m4 = m[4] + m4 -= old_weight + if significand_sum == 0 # We zeroed out a group + m[weight_index] = 0 + level_weights_nonzero_index,level_weights_nonzero_subindex = get_level_weights_nonzero_indices(exponent) + m[level_weights_nonzero_index] &= ~(0x8000000000000000 >> level_weights_nonzero_subindex) + else # We did not zero out a group + shift = signed(exponent + m[3]) + new_weight = _convert(UInt64, significand_sum << shift) + 1 + m[weight_index] = new_weight + m4 += new_weight end - if ns_track_info.lastsampled_idx == i - level = Int(ns.sampled_level_numbers[ns_track_info.lastsampled_idx_out]) - j = ns_track_info.lastsampled_idx_in - else - c = ns.entry_info.indices[i] - level = c & 4095 - 1075 - j = (c - level - 1075) >> 12 + + m[4] = m4 # This might be less than 2^32, but that's okay. 
If it is, and that's relevant, it will be corrected in _rand_slow_path + + # lookup the group by exponent + group_length_index = _convert(Int, 4 + 3*2046 + 2exponent) + group_pos = m[group_length_index-1] + group_length = m[group_length_index] + group_lastpos = _convert(Int, (group_pos-2)+2group_length) + + # TODO for perf: see if it's helpful to gate this on pos != group_lastpos + # shift the last element of the group into the spot occupied by the removed element + m[pos] = m[group_lastpos] + shifted_element = m[pos+1] = m[group_lastpos+1] + + # adjust the edit map entry of the shifted element + m[_convert(Int, shifted_element) + 10523] = _convert(UInt64, pos) << 11 + exponent + m[j] = 0 + + # When zeroing out a group, mark the group as empty so that compaction will update the group metadata and then skip over it. + if significand_sum == 0 + m[group_pos+1] = exponent | 0x8000000000000000 end - ns_track_info.lastsampled_idx = 0 - !ns.entry_info.presence[i] && throw(ArgumentError("Element $i is not present")) - ns.entry_info.presence[i] = false - - l, k = ns.level_set_map.indices[level+1075] - w, level_sampler = ns.all_levels[l] - _i, significand = level_sampler.data[j] - @assert _i == i - len = level_sampler.track_info.length - moved_entry, _ = level_sampler.data[j] = level_sampler.data[len] - level_sampler.track_info.length -= 1 - if (len & (len-1)) == 0 - level_sampler.track_info.mask = UInt(1) << (8*sizeof(len-1) - leading_zeros(len-1)) - 1 + + # shrink the group + m[group_length_index] = group_length-1 # no need to zero group entries + + nothing +end + + +ResizableWeights(len::Integer) = ResizableWeights(FixedSizeWeights(len)) +SemiResizableWeights(len::Integer) = SemiResizableWeights(FixedSizeWeights(len)) +function FixedSizeWeights(len::Integer) + m = Memory{UInt64}(undef, allocated_memory(len)) + # m .= 0 # This is here so that a sparse rendering for debugging is easier TODO for tests: set this to 0xdeadbeefdeadbeed + m[4:10523+len] .= 0 # metadata and edit map need to be zeroed but the bulk does not + m[1] = len + m[2] = 4 + # no need to set m[3] + m[10267] = 10524+len + _FixedSizeWeights(m) +end +allocated_memory(length::Integer) = 10523 + 7*length # TODO for perf: consider giving some extra constant factor allocation to avoid repeated compaction at small sizes +length_from_memory(allocated_memory::Integer) = Int((allocated_memory-10523)/7) + +Base.resize!(w::Union{SemiResizableWeights, ResizableWeights}, len::Integer) = resize!(w, Int(len)) +function Base.resize!(w::Union{SemiResizableWeights, ResizableWeights}, len::Int) + m = w.m + old_len = m[1] + if len > old_len + am = allocated_memory(len) + if am > length(m) + w isa SemiResizableWeights && throw(ArgumentError("Cannot increase the size of a SemiResizableWeights above its original allocated size. Try using a ResizableWeights instead.")) + _resize!(w, len) + else + m[1] = len + end + else + w[len+1:old_len] .= 0 # This is a necessary but highly nontrivial operation + m[1] = len end - if moved_entry != i - @assert ns.entry_info.indices[moved_entry] == (length(level_sampler)+1) << 12 + level + 1075 - ns.entry_info.indices[moved_entry] = j << 12 + level + 1075 + w +end +""" +Reallocate w with the size len, compacting w into that new memory. +Any elements if w past len must be set to zero already (that's a general invariant for +Weights, though, not just this function). 
+""" +function _resize!(w::ResizableWeights, len::Integer) + m = w.m + old_len = m[1] + m2 = Memory{UInt64}(undef, allocated_memory(len)) + # m2 .= 0 # For debugging; TODO: set to 0xdeadbeefdeadbeef to test + m2[1] = len + if len > old_len # grow + unsafe_copyto!(m2, 2, m, 2, old_len + 10523) + m2[old_len + 10524:len + 10523] .= 0 + else # shrink + unsafe_copyto!(m2, 2, m, 2, len + 10523) end - wn = w-sig(significand*exp2(level+1)) - ns.all_levels[l] = (wn, level_sampler) - - if isempty(level_sampler) # Remove a level - delete!(ns.level_set, level) - if k != 0 # Remove a sampled level - firstc = ns.track_info.firstchanged - ns.track_info.firstchanged = ifelse(k < firstc, k, firstc) - replacement = findprev(ns.level_set, ns_track_info.least_significant_sampled_level-1) - ns.level_set_map.indices[level+1075] = (l, 0) - if replacement == -10000 # We'll now have fewer than N sampled levels - ns_track_info.least_significant_sampled_level = -1075 - sl_length = ns_track_info.lastfull - ns_track_info.lastfull -= 1 - moved_level = ns.sampled_level_numbers[sl_length] - if moved_level == Int16(level) - ns.sampled_level_weights[sl_length] = 0.0 - else - ns.sampled_level_numbers[k], ns.sampled_level_numbers[sl_length] = ns.sampled_level_numbers[sl_length], ns.sampled_level_numbers[k] - ns.sampled_levels[k], ns.sampled_levels[sl_length] = ns.sampled_levels[sl_length], ns.sampled_levels[k] - ns.sampled_level_weights[k] = ns.sampled_level_weights[sl_length] - ns.sampled_level_weights[sl_length] = 0.0 - all_index, _l = ns.level_set_map.indices[ns.sampled_level_numbers[k]+1075] - @assert _l == ns.track_info.lastfull+1 - ns.level_set_map.indices[ns.sampled_level_numbers[k]+1075] = (all_index, k) - all_index = ns.level_set_map.indices[ns.sampled_level_numbers[sl_length]+1075][1] - ns.level_set_map.indices[ns.sampled_level_numbers[sl_length]+1075] = (all_index, sl_length) - end - else # Replace the removed level with the replacement - ns_track_info.least_significant_sampled_level = replacement - all_index, _zero = ns.level_set_map.indices[replacement+1075] - @assert _zero == 0 - ns.level_set_map.indices[replacement+1075] = (all_index, k) - w, replacement_level = ns.all_levels[all_index] - ns.sampled_levels[k] = Int16(all_index) - ns.sampled_level_weights[k] = flot(w, replacement) - ns.sampled_level_numbers[k] = replacement + + compact!(m2, m) + w.m = m2 + w +end + +function compact!(dst::Memory{UInt64}, src::Memory{UInt64}) + dst_i = length_from_memory(length(dst)) + 10524 + src_i = length_from_memory(length(src)) + 10524 + next_free_space = src[10267] + + while src_i < next_free_space + + # Skip over abandoned groups TODO refactor these loops for clarity + target = signed(src[src_i+1]) + while target < 0 + if unsigned(target) < 0xc000000000000000 # empty non-abandoned group; let's clean it up + @assert 0x8000000000000001 <= unsigned(target) <= 0x80000000000007fe + exponent = unsigned(target) - 0x8000000000000000 # TODO for clarity: dry this + allocs_index, allocs_subindex = get_alloced_indices(exponent) + allocs_chunk = dst[allocs_index] # TODO for perf: consider not copying metadata on out of place compaction (and consider the impact here) + log2_allocated_size_p1 = allocs_chunk >> allocs_subindex % UInt8 + allocated_size = 1<<(log2_allocated_size_p1-1) + new_chunk = allocs_chunk - UInt64(log2_allocated_size_p1) << allocs_subindex + dst[allocs_index] = new_chunk # zero out allocated size (this will force re-allocation so we can let the old, wrong pos info stand) + src_i += 2allocated_size # skip the group + 
+            else # the decaying corpse of an abandoned group. Ignore it.
+                src_i -= 2target
             end
+            src_i >= next_free_space && @goto break_outer
+            target = signed(src[src_i+1])
         end
-            elseif k != 0
-                ns.sampled_level_weights[k] = flot(wn, level)
-                firstc = ns.track_info.firstchanged
-                ns.track_info.firstchanged = ifelse(k < firstc, k, firstc)
+
+        # Trace an element of the group back to the edit info table to find the group id
+        j = target + 10523
+        exponent = src[j] & 2047
+
+        # Lookup the group in the group location table to find its length (performance optimization for copying, necessary to decide new allocated size and update pos)
+        # exponent of 0x00000000000007fe is index 6+3*2046
+        # exponent of 0x0000000000000001 is index 4+5*2046
+        group_length_index = _convert(Int, 4 + 3*2046 + 2exponent)
+        group_length = src[group_length_index]
+
+        # Update group pos in level_location_info
+        dst[group_length_index-1] += unsigned(Int64(dst_i-src_i))
+
+        # Lookup the allocated size (an alternative to scanning for the next nonzero, needed because we are setting allocated size)
+        # exponent of 0x00000000000007fe is index 6+5*2046, 2
+        # exponent of 0x00000000000007fd is index 6+5*2046, 1
+        # exponent of 0x0000000000000004 is index 5+5*2046+512, 0
+        # exponent of 0x0000000000000003 is index 5+5*2046+512, 3
+        # exponent of 0x0000000000000002 is index 5+5*2046+512, 2
+        # exponent of 0x0000000000000001 is index 5+5*2046+512, 1
+        allocs_index, allocs_subindex = get_alloced_indices(exponent)
+        allocs_chunk = dst[allocs_index]
+        log2_allocated_size = allocs_chunk >> allocs_subindex % UInt8 - 1
+        log2_new_allocated_size = group_length == 0 ? 0 : Base.top_set_bit(group_length-1)
+        new_chunk = allocs_chunk + Int64(log2_new_allocated_size - log2_allocated_size) << allocs_subindex
+        dst[allocs_index] = new_chunk
+
+        # Adjust the pos entries in edit_map (bad memory order TODO: consider unzipping edit map to improve locality here)
+        delta = unsigned(Int64(dst_i-src_i)) << 11
+        dst[j] += delta
+        for k in 1:signed(group_length)-1 # TODO: add a benchmark that stresses compaction and try hoisting this bounds checking
+            target = src[src_i+2k+1]
+            j = _convert(Int, target + 10523)
+            dst[j] += delta
+        end
+
+        # Copy the group to a compacted location
+        unsafe_copyto!(dst, dst_i, src, src_i, 2group_length)
+
+        # Advance indices
+        src_i += 2*1< length(wbs.w) && resize!(wbs.w, max(index, 2length(wbs.w)))
+    wbs.w[index] = weight
+    wbs
 end
-function SamplerIndices(ns::NestedSampler)
-    iter = Iterators.Flatten((Iterators.map(x -> x[1], @view(b[2].data[1:b[2].track_info.length])) for b in ns.all_levels))
-    SamplerIndices(ns, iter)
+function Base.append!(wbs::WeightBasedSampler, inds::AbstractVector, weights::AbstractVector)
+    axes(inds) == axes(weights) || throw(DimensionMismatch("inds and weights have different axes"))
+    min_ind,max_ind = extrema(inds)
+    min_ind < 1 && throw(BoundsError(wbs.w, min_ind))
+    max_ind > length(wbs.w) && resize!(wbs.w, max(max_ind, 2length(wbs.w)))
+    for (i,w) in zip(inds, weights)
+        wbs.w[i] = w
+    end
+    wbs
 end
-Base.iterate(inds::SamplerIndices) = Base.iterate(inds.iter)
-Base.iterate(inds::SamplerIndices, state) = Base.iterate(inds.iter, state)
-Base.eltype(::Type{<:SamplerIndices}) = Int
-Base.IteratorSize(::Type{<:SamplerIndices}) = Base.HasLength()
-Base.length(inds::SamplerIndices) = inds.ns.track_info.nvalues
+function Base.delete!(wbs::WeightBasedSampler, index)
+    index ∈ eachindex(wbs.w) && wbs.w[index] != 0 || throw(ArgumentError("Element $index is not present"))
+    wbs.w[index] = 0
+    wbs
+end
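# A minimal usage sketch of the WeightBasedSampler methods defined above (illustrative only;
# it assumes `using DynamicDiscreteSamplers, Random` and the `DynamicDiscreteSampler` alias
# introduced just below):
ds = DynamicDiscreteSampler()    # empty sampler over positive integer indices
push!(ds, 1, 2.5)                # set the weight of index 1 to 2.5
append!(ds, [2, 3], [1.0, 0.5])  # bulk insert; inds and weights must share axes
rand(ds)                         # returns 1, 2, or 3 with probability proportional to its weight
delete!(ds, 2)                   # zero out index 2; deleting it a second time throws an ArgumentError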
+Base.rand(rng::AbstractRNG, wbs::WeightBasedSampler) = rand(rng, wbs.w)
+Base.rand(rng::AbstractRNG, wbs::WeightBasedSampler, n::Integer) = [rand(rng, wbs.w) for _ in 1:n]
+
+const DynamicDiscreteSampler = WeightBasedSampler
-const DynamicDiscreteSampler = NestedSampler
+
+# Precompile
+precompile(WeightBasedSampler, ())
+precompile(push!, (WeightBasedSampler, Int, Float64))
+precompile(delete!, (WeightBasedSampler, Int))
+precompile(rand, (typeof(Random.default_rng()), WeightBasedSampler))
+precompile(rand, (WeightBasedSampler,))
 end
diff --git a/test/invariants.jl b/test/invariants.jl
new file mode 100644
index 00000000..872be619
--- /dev/null
+++ b/test/invariants.jl
@@ -0,0 +1,55 @@
+isdefined(@__MODULE__, :Memory) || const Memory = Vector # Compat for Julia < 1.11
+_get_UInt128(m::Memory, i::Integer) = UInt128(m[i]) | (UInt128(m[i+1]) << 64)
+_length_from_memory(allocated_memory::Integer) = Int((allocated_memory-10523)/7)
+function verify_weights(w::DynamicDiscreteSamplers.Weights)
+    m = w.m
+    m3 = m[3]
+    for i in 5:2050
+        shift = signed(i - 4 + m3)
+        weight = m[i]
+        shifted_significand_sum_index = 2041 + 2i
+        shifted_significand_sum = _get_UInt128(m, shifted_significand_sum_index)
+        expected_weight = UInt64(shifted_significand_sum<= findlast(i -> i == 4 || m[i] != 0, 1:2050)
+    if m[4] != 0
+        rand(w)
+        @assert m[2] == findlast(i -> m[i] != 0, 1:2050)
+    end
+end
+function verify_m4(w::DynamicDiscreteSamplers.Weights)
+    m = w.m
+    m4 = zero(UInt64)
+    for i in 5:2050
+        m4 = Base.checked_add(m4, m[i])
+    end
+    @assert m[4] == m4
+    # @assert m4 == 0 || UInt64(2)^32 <= m4 # This invariant is now maintained loosely and lazily
+end
+
+function verify_edit_map_points_to_correct_target(w::DynamicDiscreteSamplers.Weights)
+    m = w.m
+    filled_len = m[1]
+    len = _length_from_memory(length(m))
+    for i in 1:len
+        edit_map_entry = m[i+10523]
+        if i > filled_len
+            @assert edit_map_entry == 0
+        elseif edit_map_entry != 0
+            @assert m[edit_map_entry>>11 + 1] == i
+        end
+    end
+end
+
+function verify(w::DynamicDiscreteSamplers.Weights)
+    verify_weights(w)
+    verify_m2(w)
+    verify_m4(w)
+    verify_edit_map_points_to_correct_target(w)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 1606ab81..06a6119d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -7,13 +7,7 @@ using Random
 using StableRNGs
 using StatsBase
 
-@testset "unit tests" begin
-    lls = DynamicDiscreteSamplers.LinkedListSet()
-    push!(lls, 2)
-    push!(lls, 3)
-    delete!(lls, 2)
-    @test 3 in lls
-end
+@test DynamicDiscreteSamplers.DEBUG === true
 
 @testset "basic end to end tests" begin
     ds = DynamicDiscreteSampler()
@@ -84,21 +78,14 @@ end
 end
 
 @testset "Targeted statistical tests" begin
-    #Issue 8
-    for N in [1, 2, 4, 64, 128]
-        ds = DynamicDiscreteSampler{N}()
-        for i in 1:3
-            push!(ds, i, float(i))
-        end
-        delete!(ds, 2)
-        if N > 1
-            @test 0 < count(rand(ds) == 1 for _ in 1:4000) < 1200 # False positivity rate < 4e-13
-        else
-            @test count(rand(ds) == 1 for _ in 1:4000) == 0
-        end
+    ds = DynamicDiscreteSampler()
+    for i in 1:3
+        push!(ds, i, float(i))
     end
+    delete!(ds, 2)
+    @test 0 < count(rand(ds) == 1 for _ in 1:4000) < 1200 # False positivity rate < 4e-13
 end
-
+
 @testset "Randomized statistical tests" begin
     rng = StableRNG(42)
     b = 100
@@ -136,7 +123,7 @@ end
    @test pvalue(chisq_test) > 0.002
 
    ds2 = DynamicDiscreteSampler()
-
+
    append!(ds2, range, weights)
    delete!(ds2, 1)
 
@@ -195,3 +182,44 @@ if "CI" in keys(ENV)
        Aqua.test_deps_compat(DynamicDiscreteSamplers, check_extras=false)
    end
 end
+
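# The `False positivity rate < 4e-13` annotation in the targeted statistical test above can be
# sanity-checked with a short binomial-tail computation. A sketch (not part of the test suite),
# assuming the sampler is exact, so after `delete!(ds, 2)` index 1 is drawn with probability
# 1/(1+3) = 0.25 on each of the 4000 draws (StatsFuns is also used by test/statistical.jl below):
using StatsFuns
p_high = binomccdf(4000, 0.25, 1199)  # P(count >= 1200), the upper failure mode of `0 < count < 1200`
p_zero = binompdf(4000, 0.25, 0)      # P(count == 0), the lower failure mode
p_high + p_zero                       # false-positive probability of the test under exact sampling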
+@testset "stress test huge probability swings" begin
+    ds = DynamicDiscreteSampler()
+    push!(ds, 1, 1e-300)
+    @test rand(ds) == 1
+    push!(ds, 2, 1e300)
+    @test rand(ds) == 2
+    delete!(ds, 2)
+    @test rand(ds) == 1
+end
+
+include("weights.jl")
+
+function error_d03fb()
+    ds = DynamicDiscreteSampler()
+    for i in 1:1_500
+        push!(ds, i, 0.1)
+    end
+    for i in 1:25_000
+        push!(ds, rand(ds), exp(8randn()))
+    end
+end
+error_d03fb() # This threw AssertionError: 48 <= Base.top_set_bit(m[4]) <= 50 90% of the time on d03fb84d1b62272c5d6ab54c49e643af9b87201b
+
+function error_d03fb_2(n)
+    w = DynamicDiscreteSamplers.FixedSizeWeights(2^n+1);
+    for i in 1:2^n-1
+        w[i] = .99*.5^Base.top_set_bit(i)
+    end
+    w[2^n] = .99
+    w[2^n+1] = 1e100
+    w[2^n+1] = 0
+    @test UInt64(2)^32 < w.m[3]
+end
+error_d03fb_2.(1:15)
+
+ds = DynamicDiscreteSampler()
+push!(ds, 2, 1e308)
+delete!(ds, 2)
+push!(ds, 2, 1e308) # This previously threw
+@test rand(ds) == 2
diff --git a/test/statistical.jl b/test/statistical.jl
new file mode 100644
index 00000000..68201417
--- /dev/null
+++ b/test/statistical.jl
@@ -0,0 +1,74 @@
+using StatsFuns, Random
+
+# Not using HypothesisTests's ChisqTest because of https://github.com/JuliaStats/HypothesisTests.jl/issues/281
+# Not using ChisqTest at all because it reports very low p values with high probability for some distributions
+
+"""
+    statistical_test(rng, sampler, expected_probabilities, n)
+
+For all `p`, if `sampler` samples according to `expected_probabilities` then return a number
+less than `p` with probability at most `p`.
+
+Also, makes an effort to return low numbers whenever `sampler` does not sample according to
+`expected_probabilities`.
+
+Calls `rand(rng, sampler)` `n` times.
+
+Inspired by https://github.com/JuliaStats/Distributions.jl/blob/1e6801da6678164b13330cc1f16e670768d27330/test/testutils.jl#L99
+"""
+function _statistical_test(rng, sampler, expected_probabilities, n)
+    sample = similar(expected_probabilities, Int)
+    sample .= 0
+    for _ in 1:n
+        sample[rand(rng, sampler)] += 1
+    end
+
+    nonzeros = 0
+    for i in eachindex(expected_probabilities, sample)
+        if iszero(expected_probabilities[i])
+            sample[i] != 0 && return 0.0
+        else
+            nonzeros += 1
+        end
+    end
+
+    mn = 1/nonzeros # TODO: justify why we can get away with not dividing by 2 without false positives
+    for i in eachindex(expected_probabilities, sample)
+        if !iszero(expected_probabilities[i])
+            p_le = binomcdf(n, expected_probabilities[i], sample[i])
+            p_ge = 1-binomcdf(n, expected_probabilities[i], sample[i]-1)
+            mn = min(mn, p_le, p_ge)
+        end
+    end
+    mn*nonzeros
+end
+
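# An illustrative spot-check of the guarantee documented above (not part of the test files),
# assuming DynamicDiscreteSamplers is loaded and using a FixedSizeWeights as the sampler: with
# two equal weights, the helper should return a value below 0.01 at most 1% of the time.
w_demo = DynamicDiscreteSamplers.FixedSizeWeights(2)
w_demo[1] = 1.0
w_demo[2] = 1.0
p_demo = _statistical_test(Random.default_rng(), w_demo, [0.5, 0.5], 10_000)
p_demo > 0.01  # holds with probability at least 0.99 when w_demo samples uniformly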
+FALSE_POSITIVITY_ACCUMULATOR::Float64 = isdefined(@__MODULE__, :FALSE_POSITIVITY_ACCUMULATOR) ? FALSE_POSITIVITY_ACCUMULATOR : 0.0;
+
+function _statistical_test(rng, sampler, expected_probabilities)
+    global FALSE_POSITIVITY_ACCUMULATOR += 1e-8
+
+    p = _statistical_test(rng, sampler, expected_probabilities, 1_000)
+    p > .1 && return true
+    for _ in 1:7
+        p = _statistical_test(rng, sampler, expected_probabilities, 10_000)
+        p > .1 && return true
+    end
+
+    println(stderr, "statistical test failure")
+    global FAILED_SAMPLER = sampler
+    global FAILED_EXPECTED_PROBABILITIES = expected_probabilities
+    if isinteractive()
+        println("reproduce with `statistical_test(FAILED_SAMPLER, FAILED_EXPECTED_PROBABILITIES)`")
+    else
+        @show sampler expected_probabilities
+        @show sampler.m
+    end
+    false
+end
+
+function statistical_test(rng, sampler, expected_probabilities)
+    @test _statistical_test(rng, sampler, expected_probabilities)
+end
+statistical_test(sampler, expected_probabilities) =
+    statistical_test(Random.default_rng(), sampler, expected_probabilities)
diff --git a/test/weights.jl b/test/weights.jl
new file mode 100644
index 00000000..28b68dd8
--- /dev/null
+++ b/test/weights.jl
@@ -0,0 +1,357 @@
+using DynamicDiscreteSamplers, Test
+
+@test DynamicDiscreteSamplers.FixedSizeWeights(10) isa DynamicDiscreteSamplers.FixedSizeWeights
+@test DynamicDiscreteSamplers.ResizableWeights(10) isa DynamicDiscreteSamplers.ResizableWeights
+@test DynamicDiscreteSamplers.SemiResizableWeights(10) isa DynamicDiscreteSamplers.SemiResizableWeights
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+
+@test_throws ArgumentError("collection must be non-empty") rand(w)
+
+@test 1 === (w[1] = 1)
+
+@test rand(w) === 1
+
+@test_throws BoundsError w[0]
+@test_throws BoundsError w[11]
+@test w[1] === 1.0
+for i in 2:10
+    @test w[i] === 0.0
+end
+
+@test 0 === (w[1] = 0)
+@test w[1] === 0.0
+
+@test_throws ArgumentError("collection must be non-empty") rand(w)
+
+w[1] = 1.5
+@test w[1] === 1.5
+
+w[1] = 2
+@test w[1] === 2.0
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[1] = 3
+w[2] = 2
+w[3] = 3
+@test w[1] == 3
+@test w[2] == 2
+@test w[3] == 3
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[9] = 3
+w[7] = 3
+w[1] = 3
+@test w[9] == 3
+@test w[7] == 3
+@test w[1] == 3
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[8] = 0.549326222415666
+w[6] = 1.0149666786255531
+w[3] = 0.8210275222825218
+@test w[8] === 0.549326222415666
+@test w[6] === 1.0149666786255531
+@test w[3] === 0.8210275222825218
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[8] = 3.2999782300326728
+w[9] = 0.7329714939310719
+w[3] = 2.397108987310203
+@test w[8] === 3.2999782300326728
+@test w[9] === 0.7329714939310719
+@test w[3] === 2.397108987310203
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[1] = 1.5
+w[2] = 1.6
+w[1] = 1.7
+@test w[1] === 1.7
+@test w[2] === 1.6
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[1] = 1
+w[2] = 1e8
+@test w[1] == 1
+@test w[2] === 1e8
+
+let w = DynamicDiscreteSamplers.FixedSizeWeights(2)
+    w[1] = 1.1
+    w[2] = 1.9
+    twos = 0
+    n = 10_000
+    for _ in 1:n
+        x = rand(w)
+        @test x ∈ 1:2
+        if x == 2
+            twos += 1
+        end
+    end
+    @test (w[2]/(w[1]+w[2])) === 1.9/3
+    expected = n*1.9/3
+    stdev = .5sqrt(n)
+    @test abs(twos-expected) < 4stdev
+end
+
+let w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+    for i in 1:40
+        w[1] = 1.5*2.0^i
+        @test w[1] === 1.5*2.0^i
+    end
+end
+
+w = DynamicDiscreteSamplers.ResizableWeights(10)
+resize!(w, 20)
+resize!(w, unsigned(30))
+
+w = DynamicDiscreteSamplers.ResizableWeights(10)
+w[5] = 3
+resize!(w, 20)
+v = fill(0.0, 20)
+v[5] = 3
+@test w == v
+@test rand(w) == 5
+w[11] = v[11] = 3.5
+@test w == v
+
+w = DynamicDiscreteSamplers.ResizableWeights(10)
+w[1] = 1.2
+w[1] = 0
+resize!(w, 20)
+w[15] = 1.3
+@test w[11] == 0
+
+w = DynamicDiscreteSamplers.ResizableWeights(10)
+w[1] = 1.2
+w[2] = 1.3
+w[2] = 0
+resize!(w, 20)
+
+w = DynamicDiscreteSamplers.ResizableWeights(10)
+w[5] = 1.2
+w[6] = 1.3
+w[6] = 0
+resize!(w, 20)
+w[15] = 2.1
+resize!(w, 40)
+w[30] = 4.1
+w[22] = 2.2 # This previously threw
+
+w = DynamicDiscreteSamplers.ResizableWeights(10);
+w[5] = 1.5
+resize!(w, 3)
+resize!(w, 20) # This previously threw
+@test w == fill(0.0, 20)
+
+w = DynamicDiscreteSamplers.ResizableWeights(2)
+w[1] = .3
+w[2] = 1.1
+w[2] = .4
+w[2] = 2.1
+w[1] = .6
+w[2] = .7 # This used to throw
+@test w == [.6, .7]
+
+w = DynamicDiscreteSamplers.ResizableWeights(1)
+w[1] = 18
+w[1] = .9
+w[1] = 1.3
+w[1] = .01
+w[1] = .9
+@test w == [.9]
+resize!(w, 2)
+@test w == [.9, 0]
+
+w = DynamicDiscreteSamplers.ResizableWeights(2)
+w[2] = 19
+w[2] = 10
+w[2] = .9
+w[1] = 2.1
+w[1] = 1.1
+w[1] = 0.7
+@test w == [.7, .9]
+
+w = DynamicDiscreteSamplers.ResizableWeights(6)
+resize!(w, 2108)
+w[296] = 3.686559798150465e39
+w[296] = 0
+w[1527] = 1.0763380850925863
+w[355] = 0.01640346013465141
+w[881] = 79.54017710382257
+w[437] = 3.848925751307115
+w[571] = 1.0339246678117338
+w[762] = 0.7965409844985439
+w[1814] = 1.3864105787251011e-12
+w[881] = 0
+w[1059] = 0.9443147177405427
+w[668] = 255825.83047903617
+w[23] = 1.0173292806984486
+w[377] = 6.652796808681465
+w[668] = 0
+w[1939] = 7.075668919342077e18
+w[979] = 0.8922993294513122
+resize!(w, 1612) # This previously threw an AssertionError: 48 <= Base.top_set_bit(m[4]) <= 49
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+for x in (floatmin(Float64), prevfloat(1.0, 2), prevfloat(1.0), 1.0, nextfloat(1.0), nextfloat(1.0, 2), floatmax(Float64))
+    w[1] = x # This previously threw on prevfloat(1.0) and floatmax(Float64)
+    @test w[1] === x
+end
+
+include("invariants.jl")
+
+w = DynamicDiscreteSamplers.ResizableWeights(31)
+w[11] = 9.923269000574892e-8
+w[23] = 0.9876032886161744
+w[31] = 1.1160998022859043
+verify(w)
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[1] = floatmin(Float64)
+w[2] = floatmax(Float64)
+w[2] = 0 # This previously threw an assertion error due to overflow when estimating sum of level weights
+verify(w)
+
+w = DynamicDiscreteSamplers.FixedSizeWeights(9)
+v = zeros(9)
+v[4] = w[4] = 2.44
+v[5] = w[5] = 0.76
+v[6] = w[6] = 0.61
+v[7] = w[7] = 0.62
+v[9] = w[9] = 2.15
+v[1] = w[1] = 1.65
+v[7] = w[7] = 1.46
+v[8] = w[8] = 0.25
+v[2] = w[2] = 0.93
+v[3] = w[3] = 3.67
+v[6] = w[6] = 9.92
+v[5] = w[5] = 1.72
+v[6] = w[6] = 0.70
+v[8] = w[8] = 0.72
+v[5] = w[5] = 0.20
+v[1] = w[1] = 0.71
+v[3] = w[3] = 0.92
+verify(w)
+@test v == w
+
+w = DynamicDiscreteSamplers.ResizableWeights(2)
+w[1] = 0.95
+w[2] = 6.41e14
+verify(w)
+
+# This test catches a bug that was not revealed by the RNG tests below
+w = DynamicDiscreteSamplers.FixedSizeWeights(3);
+w[1] = 1.5
+w[2] = prevfloat(1.5)
+w[3] = 2^25
+verify(w)
+
+# This test catches a bug that was not revealed by the RNG tests below.
+# The final line is calibrated to have about a 50% fail rate on that bug
+# and run in about 3 seconds:
+w = DynamicDiscreteSamplers.FixedSizeWeights(2046*2048)
+w .= repeat(ldexp.(1.0, -1022:1023), inner=2048)
+w[(2046-16)*2048+1:2046*2048] .= 0
+@test w.m[4] < 2.0^32*1.1 # Confirm that we created an interesting condition
+f(w,n) = sum(Int64(rand(w)) for _ in 1:n)
+verify(w)
+@test f(w, 2^27) ≈ 4.1543685e6*2^27 rtol=1e-6 # This should fail less than 1e-40 of the time
+
+# These tests have never revealed a bug that was not revealed by one of the above tests:
+w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+w[1] = 1
+w[2] = 1e100
+@test rand(w) === 2
+w[3] = 1e-100
+@test rand(w) === 2
+w[2] = 0
+@test rand(w) === 1
+w[1] = 0
+@test rand(w) === 3
+w[3] = 0
+@test_throws ArgumentError("collection must be non-empty") rand(w)
+
+let
+    for _ in 1:10000
+        w = DynamicDiscreteSamplers.FixedSizeWeights(10)
+        v = [w[i] for i in 1:10]
+        for _ in 1:10
+            i = rand(1:10)
+            x = rand((0.0, exp(10randn())))
+            w[i] = x
+            v[i] = x
+            @test all(v[i] === w[i] for i in 1:10)
+        end
+    end
+end
+
+# This alone probably catches all bugs that are caught by tests above (with one exception).
+# However, whenever we identify and fix a bug, we add a specific test for it above.
+include("statistical.jl")
+try
+    let
+        print("weights.jl randomized tests: 0%")
+        for rep in 1:1000
+            if rep % 10 == 0
+                print("\rweights.jl randomized tests: $(rep÷10)%")
+            end
+            global LOG = []
+            len = rand(1:100)
+            push!(LOG, len)
+            w = DynamicDiscreteSamplers.ResizableWeights(len)
+            v = fill(0.0, len)
+            resize = rand(Bool) # Some behavior emerges when not resizing for a long period
+            for _ in 1:rand((10,100,3000))
+                @test v == w
+                verify(w)
+                if rand() < .01
+                    sm = sum(v)
+                    sm == 0 || statistical_test(w, v ./ sm)
+                end
+                x = rand()
+                if x < .2 && !all(iszero, v)
+                    i = rand(findall(!iszero, v))
+                    push!(LOG, i => 0)
+                    v[i] = 0
+                    w[i] = 0
+                elseif x < .4 && !all(iszero, v)
+                    i = rand(w)
+                    push!(LOG, i => 0)
+                    v[i] = 0
+                    w[i] = 0
+                elseif x < .9 || !resize
+                    i = rand(eachindex(v))
+                    x = exp(rand((.1, 7, 100))*randn())
+                    push!(LOG, i => x)
+                    v[i] = x
+                    w[i] = x
+                else
+                    l_old = length(v)
+                    l_new = rand(1:rand((10,100,3000)))
+                    push!(LOG, resize! => l_new)
+                    resize!(v, l_new)
+                    resize!(w, l_new)
+                    if l_new > l_old
+                        v[l_old+1:l_new] .= 0
+                    end
+                end
+            end
+        end
+        println()
+    end
+    println("These tests should fail due to random noise no more than $FALSE_POSITIVITY_ACCUMULATOR of the time")
+catch
+    println("Reproducer:\n```julia")
+    for L in LOG
+        if L isa Int
+            println("w = DynamicDiscreteSamplers.ResizableWeights($L)")
+        elseif first(L) === resize!
+            println("resize!(w, $(last(L)))")
+        else
+            println("w[$(first(L))] = $(last(L))")
+        end
+    end
+    println("```")
+    rethrow()
+end
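# For reference: when the randomized loop above fails, the catch block prints a self-contained
# reproducer assembled from LOG. A hypothetical failure would print a script of this shape
# (the specific values below are made up for illustration):
#
#     w = DynamicDiscreteSamplers.ResizableWeights(7)
#     w[3] = 2.5
#     w[3] = 0
#     resize!(w, 12)
#     w[5] = 0.125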