Skip to content

Commit bc3931f

Browse files
committed
perf: optimize permutations by implementing it via multiset_permutations
As it currently stands, `multiset_permutations` is more efficient than `permutations`; see #151. We can exploit it to optimize `permutations`.
1 parent ab33a23 commit bc3931f

File tree

1 file changed

+19
-35
lines changed

1 file changed

+19
-35
lines changed

src/permutations.jl

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,42 +14,26 @@ struct Permutations{T}
1414
length::Int
1515
end
1616

17-
function has_repeats(state::Vector{Int})
18-
# This can be safely marked inbounds because of the type restriction in the signature.
19-
# If the type restriction is ever loosened, please check safety of the `@inbounds`
20-
@inbounds for outer in eachindex(state)
21-
for inner in (outer+1):lastindex(state)
22-
if state[outer] == state[inner]
23-
return true
24-
end
25-
end
26-
end
27-
return false
28-
end
29-
30-
function increment!(state::Vector{Int}, min::Int, max::Int)
31-
state[end] += 1
32-
for i in reverse(eachindex(state))[firstindex(state):end-1]
33-
if state[i] > max
34-
state[i] = min
35-
state[i-1] += 1
36-
end
37-
end
38-
end
39-
40-
function next_permutation!(state::Vector{Int}, min::Int, max::Int)
41-
while true
42-
increment!(state, min, max)
43-
has_repeats(state) || break
44-
end
45-
end
46-
47-
function Base.iterate(p::Permutations, state::Vector{Int}=fill(firstindex(p.data), p.length))
48-
next_permutation!(state, firstindex(p.data), lastindex(p.data))
49-
if first(state) > lastindex(p.data)
50-
return nothing
17+
# The following code basically implements `permutations` in terms of `multiset_permutations` as
18+
#
19+
# permutations(a, t::Integer=length(a)) = Iterators.map(
20+
# indices -> [a[i] for i in indices],
21+
# multiset_permutations(eachindex(a), t))
22+
#
23+
# with the difference that we can also define `eltype(::Permutations)`, which is used in some tests.
24+
25+
function Base.iterate(p::Permutations, state=nothing)
26+
if isnothing(state)
27+
mp = multiset_permutations(eachindex(p.data), p.length)
28+
it = iterate(mp)
29+
if isnothing(it) return nothing end
30+
else
31+
mp, mp_state = state
32+
it = iterate(mp, mp_state)
33+
if isnothing(it) return nothing end
5134
end
52-
[p.data[i] for i in state], state
35+
indices, mp_state = it
36+
return [p.data[i] for i in indices], (; mp, mp_state)
5337
end
5438

5539
function Base.length(p::Permutations)

0 commit comments

Comments
 (0)