11
2+ using Distributions
3+
24# Inspired by https://www.aarondefazio.com/tangentially/?p=58
35
46mutable struct DynamicInfo
@@ -148,9 +150,33 @@ function Base.append!(sp::DynamicSampler, inds::Union{UnitRange, AbstractArray},
148150 return sp
149151end
150152
151- Base. rand (sp:: DynamicSampler , n:: Integer ) = [rand (sp) for _ in 1 : n]
153+ function Base. rand (sp:: DynamicSampler , n:: Integer )
154+ sp. info. totweight = sum (sp. level_weights)
155+ n_each = rand (sp. rng, Multinomial (n, sp. level_weights ./ sp. info. totweight))
156+ randinds = Vector {Int} (undef, n)
157+ q = 1
158+ for (i, k) in enumerate (n_each)
159+ bucket = sp. level_buckets[i]
160+ level_size = length (bucket)
161+ for _ in 1 : k
162+ randinds[q] = extract_rand_idx (sp, i, bucket, level_size)[1 ]
163+ q += 1
164+ end
165+ end
166+ return randinds
167+ end
168+
152169@inline function Base. rand (sp:: DynamicSampler )
153- # Sample a level using the CDF method
170+ level, bucket, level_size = extract_rand_level (sp)
171+ idx, weight, level, idx_in_level = extract_rand_idx (sp, level, bucket, level_size)
172+ sp. info. idx = idx
173+ sp. info. weight = weight
174+ sp. info. level = level
175+ sp. info. idx_in_level = idx_in_level
176+ return idx
177+ end
178+
179+ @inline function extract_rand_level (sp:: DynamicSampler )
154180 u = rand (sp. rng) * sp. info. totweight
155181 cumulative_weight = 0.0
156182 level = length (sp. level_weights)
@@ -170,6 +196,10 @@ Base.rand(sp::DynamicSampler, n::Integer) = [rand(sp) for _ in 1:n]
170196 level, bucket = first (Iterators. drop (notempty, rand_notempty- 1 ))
171197 level_size = length (bucket)
172198 end
199+ return level, bucket, level_size
200+ end
201+
202+ @inline function extract_rand_idx (sp, level, bucket, level_size)
173203 level_max = sp. level_max[level]
174204 # Now sample within the level using rejection sampling
175205 u = rand (sp. rng) * level_size
@@ -184,11 +214,7 @@ Base.rand(sp::DynamicSampler, n::Integer) = [rand(sp) for _ in 1:n]
184214 idx_in_level = intu + 1
185215 idx, weight = bucket[idx_in_level]
186216 end
187- sp. info. idx = idx
188- sp. info. weight = weight
189- sp. info. level = level
190- sp. info. idx_in_level = idx_in_level
191- return idx
217+ return idx, weight, level, idx_in_level
192218end
193219
194220function Base. delete! (sp:: DynamicSampler , indices:: Union{UnitRange, Vector{<:Integer}} )
0 commit comments