Skip to content

Commit 6f9bb22

Browse files
authored
Some more improvements (#20)
1 parent e4f7ab2 commit 6f9bb22

3 files changed

Lines changed: 43 additions & 8 deletions

File tree

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
name = "DynamicSampling"
22
uuid = "2083aeaf-6258-5d07-89fc-32cf5060c837"
3-
version = "0.4.6"
3+
version = "0.4.5"
44

55
[deps]
6+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
67
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
78

89
[compat]
10+
Distributions = "0.25"
911
Random = "1"
1012
julia = "1.6"

src/DynamicWeightedSampler.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11

2+
using Distributions
3+
24
# Inspired by https://www.aarondefazio.com/tangentially/?p=58
35

46
mutable struct DynamicInfo
@@ -148,9 +150,33 @@ function Base.append!(sp::DynamicSampler, inds::Union{UnitRange, AbstractArray},
148150
return sp
149151
end
150152

151-
Base.rand(sp::DynamicSampler, n::Integer) = [rand(sp) for _ in 1:n]
153+
"""
    rand(sp::DynamicSampler, n::Integer)

Draw `n` indices from `sp` and return them as a `Vector{Int}`.

The number of draws assigned to each level is sampled in one shot from a
`Multinomial` whose probabilities are the level weights normalised by their
sum; each individual draw inside a level is then delegated to
`extract_rand_idx`. As a side effect `sp.info.totweight` is refreshed from
`sp.level_weights`. Unlike the single-draw `rand(sp)`, the other fields of
`sp.info` are not updated here.

The draws are produced level by level, so the result is shuffled before being
returned; otherwise the output would be grouped by level and any caller
consuming it in order (or taking a prefix) would see a biased sequence.
"""
function Base.rand(sp::DynamicSampler, n::Integer)
    # Recompute the total from the per-level weights before normalising.
    sp.info.totweight = sum(sp.level_weights)
    n_each = rand(sp.rng, Multinomial(n, sp.level_weights ./ sp.info.totweight))
    randinds = Vector{Int}(undef, n)
    q = 1
    for (i, k) in enumerate(n_each)
        bucket = sp.level_buckets[i]
        level_size = length(bucket)
        for _ in 1:k
            # Only the sampled index (first element of the returned tuple) is kept.
            randinds[q] = extract_rand_idx(sp, i, bucket, level_size)[1]
            q += 1
        end
    end
    # In-place Fisher-Yates shuffle driven by the sampler's own RNG, so the
    # returned sequence is exchangeable rather than ordered by level.
    for i in n:-1:2
        j = rand(sp.rng, 1:i)
        randinds[i], randinds[j] = randinds[j], randinds[i]
    end
    return randinds
end
168+
152169
"""
    rand(sp::DynamicSampler)

Draw a single index from `sp`.

A level is chosen with `extract_rand_level`, then an element of that level is
chosen with `extract_rand_idx`. The details of the draw (index, weight, level
and position inside the level) are recorded in `sp.info` before the index is
returned.
"""
@inline function Base.rand(sp::DynamicSampler)
    chosen_level, chosen_bucket, bucket_len = extract_rand_level(sp)
    draw = extract_rand_idx(sp, chosen_level, chosen_bucket, bucket_len)
    # `draw` is (idx, weight, level, idx_in_level); store all four on the sampler.
    info = sp.info
    info.idx, info.weight, info.level, info.idx_in_level = draw
    return draw[1]
end
178+
179+
@inline function extract_rand_level(sp::DynamicSampler)
154180
u = rand(sp.rng) * sp.info.totweight
155181
cumulative_weight = 0.0
156182
level = length(sp.level_weights)
@@ -170,6 +196,10 @@ Base.rand(sp::DynamicSampler, n::Integer) = [rand(sp) for _ in 1:n]
170196
level, bucket = first(Iterators.drop(notempty, rand_notempty-1))
171197
level_size = length(bucket)
172198
end
199+
return level, bucket, level_size
200+
end
201+
202+
@inline function extract_rand_idx(sp, level, bucket, level_size)
173203
level_max = sp.level_max[level]
174204
# Now sample within the level using rejection sampling
175205
u = rand(sp.rng) * level_size
@@ -184,11 +214,7 @@ Base.rand(sp::DynamicSampler, n::Integer) = [rand(sp) for _ in 1:n]
184214
idx_in_level = intu + 1
185215
idx, weight = bucket[idx_in_level]
186216
end
187-
sp.info.idx = idx
188-
sp.info.weight = weight
189-
sp.info.level = level
190-
sp.info.idx_in_level = idx_in_level
191-
return idx
217+
return idx, weight, level, idx_in_level
192218
end
193219

194220
function Base.delete!(sp::DynamicSampler, indices::Union{UnitRange, Vector{<:Integer}})

test/weighted_sampler_tests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@
4747
chisq_test = ChisqTest(counts_est, ps_exact)
4848
@test pvalue(chisq_test) > 0.05
4949

50+
samples_counts = countmap(rand(s3, 10^5))
51+
counts_est = [samples_counts[i] for i in 1:b]
52+
ps_exact = [i/((b ÷ 2)*(b+1)) for i in 1:b]
53+
54+
chisq_test = ChisqTest(counts_est, ps_exact)
55+
@test pvalue(chisq_test) > 0.05
56+
5057
for i in 1:(b ÷ 2)
5158
delete!(s3, i)
5259
end

0 commit comments

Comments
 (0)