diff --git a/Project.toml b/Project.toml index 094f9f3..334a0a0 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Flux = "0.10" GPUArrays = "2, 3" MacroTools = "0.5" ProgressMeter = "1.2" -ReinforcementLearningBase = "0.6" +ReinforcementLearningBase = "0.7" StatsBase = "0.32" Zygote = "0.4" julia = "1.3" diff --git a/src/ReinforcementLearningCore.jl b/src/ReinforcementLearningCore.jl index 746ea73..2106842 100644 --- a/src/ReinforcementLearningCore.jl +++ b/src/ReinforcementLearningCore.jl @@ -13,7 +13,7 @@ export RLCore include("extensions/extensions.jl") include("utils/utils.jl") -include("core/core.jl") include("components/components.jl") +include("core/core.jl") end # module diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl new file mode 100644 index 0000000..a4e8cee --- /dev/null +++ b/src/components/agents/abstract_agent.jl @@ -0,0 +1,68 @@ +export AbstractAgent, + get_role, + PreExperimentStage, + PostExperimentStage, + PreEpisodeStage, + PostEpisodeStage, + PreActStage, + PostActStage, + PRE_EXPERIMENT_STAGE, + POST_EXPERIMENT_STAGE, + PRE_EPISODE_STAGE, + POST_EPISODE_STAGE, + PRE_ACT_STAGE, + POST_ACT_STAGE + +""" + (agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -> action + (agent::AbstractAgent)(stage::AbstractStage, obs) + +Similar to [`AbstractPolicy`](@ref), an agent is also a functional object which takes in an observation and returns an action. +The main difference is that, we divide an experiment into the following stages: + +- `PRE_EXPERIMENT_STAGE` +- `PRE_EPISODE_STAGE` +- `PRE_ACT_STAGE` +- `POST_ACT_STAGE` +- `POST_EPISODE_STAGE` +- `POST_EXPERIMENT_STAGE` + +In each stage, different types of agents may have different behaviors, like updating experience buffer, environment model or policy. +""" +abstract type AbstractAgent end + +function get_role(::AbstractAgent) end + +""" + +-----------------------------------------------------------+ + |Episode | + | | +PRE_EXPERIMENT_STAGE | PRE_ACT_STAGE POST_ACT_STAGE | POST_EXPERIMENT_STAGE + | | | | | | + v | +-----+ v +-------+ v +-----+ | v + --------------------->+ env +------>+ agent +------->+ env +---> ... ------->...... 
+ | ^ +-----+ obs +-------+ action +-----+ ^ | + | | | | + | +--PRE_EPISODE_STAGE POST_EPISODE_STAGE----+ | + | | + | | + +-----------------------------------------------------------+ +""" +abstract type AbstractStage end + +struct PreExperimentStage <: AbstractStage end +struct PostExperimentStage <: AbstractStage end +struct PreEpisodeStage <: AbstractStage end +struct PostEpisodeStage <: AbstractStage end +struct PreActStage <: AbstractStage end +struct PostActStage <: AbstractStage end + +const PRE_EXPERIMENT_STAGE = PreExperimentStage() +const POST_EXPERIMENT_STAGE = PostExperimentStage() +const PRE_EPISODE_STAGE = PreEpisodeStage() +const POST_EPISODE_STAGE = PostEpisodeStage() +const PRE_ACT_STAGE = PreActStage() +const POST_ACT_STAGE = PostActStage() + +(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) +function (agent::AbstractAgent)(stage::AbstractStage, obs) end diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index 411bae7..149bd2e 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -11,15 +11,15 @@ Generally speaking, it does nothing but update the trajectory and policy appropr - `policy`::[`AbstractPolicy`](@ref): the policy to use - `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment -- `role=DEFAULT_PLAYER`: used to distinguish different agents +- `role=:DEFAULT_PLAYER`: used to distinguish different agents """ Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent policy::P trajectory::T - role::R = DEFAULT_PLAYER + role::R = :DEFAULT_PLAYER end -RLBase.get_role(agent::Agent) = agent.role +get_role(agent::Agent) = agent.role ##### # EpisodicCompactSARTSATrajectory diff --git a/src/components/agents/agents.jl b/src/components/agents/agents.jl index ecf60b2..3d063e0 100644 --- a/src/components/agents/agents.jl +++ b/src/components/agents/agents.jl @@ -1,2 +1,3 @@ +include("abstract_agent.jl") include("agent.jl") include("dyna_agent.jl") diff --git a/src/components/agents/dyna_agent.jl b/src/components/agents/dyna_agent.jl index 2aee273..5c60eb4 100644 --- a/src/components/agents/dyna_agent.jl +++ b/src/components/agents/dyna_agent.jl @@ -27,11 +27,11 @@ Base.@kwdef struct DynaAgent{ policy::P model::M trajectory::B - role::R = DEFAULT_PLAYER + role::R = :DEFAULT_PLAYER plan_step::Int = 10 end -RLBase.get_role(agent::DynaAgent) = agent.role +get_role(agent::DynaAgent) = agent.role function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::PreEpisodeStage, diff --git a/src/components/approximators/abstract_approximator.jl b/src/components/approximators/abstract_approximator.jl new file mode 100644 index 0000000..556d678 --- /dev/null +++ b/src/components/approximators/abstract_approximator.jl @@ -0,0 +1,41 @@ +export AbstractApproximator, + ApproximatorStyle, + Q_APPROXIMATOR, + QApproximator, + V_APPROXIMATOR, + VApproximator + +""" + (app::AbstractApproximator)(obs) + +An approximator is a functional object for value estimation. +It serves as a black box to provides an abstraction over different +kinds of approximate methods (for example DNN provided by Flux or Knet). +""" +abstract type AbstractApproximator end + +""" + update!(a::AbstractApproximator, correction) + +Usually the `correction` is the gradient of inner parameters. 
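To make this contract concrete, here is a minimal sketch (not part of the patch) of a value approximator implementing the interface. `LinearVApproximator` and its field are hypothetical names; the snippet assumes it is evaluated inside `RLCore`, where `RLBase`, `AbstractApproximator`, and the `V_APPROXIMATOR` trait introduced further down this file are in scope, and it treats the observation as a plain feature vector:

```julia
using LinearAlgebra: dot

# Hypothetical linear state-value approximator (illustration only).
struct LinearVApproximator <: AbstractApproximator
    weights::Vector{Float64}
end

# An approximator is a functional object: calling it returns the estimated value.
(app::LinearVApproximator)(obs::AbstractVector) = dot(app.weights, obs)

# Here the `correction` plays the role of a gradient w.r.t. the inner parameters.
function RLBase.update!(app::LinearVApproximator, correction::AbstractVector)
    app.weights .-= correction
end

# Declare what this approximator estimates (trait defined later in this file).
ApproximatorStyle(::LinearVApproximator) = V_APPROXIMATOR
```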
+""" +function RLBase.update!(a::AbstractApproximator, correction) end + +##### +# traits +##### + +abstract type AbstractApproximatorStyle end + +""" +Used to detect what an [`AbstractApproximator`](@ref) is approximating. +""" +function ApproximatorStyle(::AbstractApproximator) end + +struct QApproximator <: AbstractApproximatorStyle end + +const Q_APPROXIMATOR = QApproximator() + +struct VApproximator <: AbstractApproximatorStyle end + +const V_APPROXIMATOR = VApproximator() diff --git a/src/components/approximators/approximators.jl b/src/components/approximators/approximators.jl index 1006c0c..c6b3b89 100644 --- a/src/components/approximators/approximators.jl +++ b/src/components/approximators/approximators.jl @@ -1,2 +1,3 @@ +include("abstract_approximator.jl") include("tabular_approximator.jl") include("neural_network_approximator.jl") diff --git a/src/components/approximators/neural_network_approximator.jl b/src/components/approximators/neural_network_approximator.jl index d80450a..6090d43 100644 --- a/src/components/approximators/neural_network_approximator.jl +++ b/src/components/approximators/neural_network_approximator.jl @@ -2,12 +2,6 @@ export NeuralNetworkApproximator using Flux -struct NeuralNetworkApproximator{T,M,O,P} <: AbstractApproximator - model::M - optimizer::O - params::P -end - """ NeuralNetworkApproximator(;kwargs) @@ -18,45 +12,18 @@ Use a DNN model for value estimation. - `model`, a Flux based DNN model. - `optimizer` - `parameters=params(model)` -- `kind=Q_APPROXIMATOR`, specify the type of model. """ -function NeuralNetworkApproximator(; - model::M, - optimizer::O, - parameters::P = params(model), - kind = Q_APPROXIMATOR, -) where {M,O,P} - NeuralNetworkApproximator{kind,M,O,P}(model, optimizer, parameters) +Base.@kwdef struct NeuralNetworkApproximator{M,O,P} <: AbstractApproximator + model::M + optimizer::O + params::P = params(model) end -device(app::NeuralNetworkApproximator) = device(app.model) +(app::NeuralNetworkApproximator)(x) = app.model(x) Flux.params(app::NeuralNetworkApproximator) = app.params -(app::NeuralNetworkApproximator)(s::AbstractArray) = app.model(s) -(app::NeuralNetworkApproximator{Q_APPROXIMATOR})(s::AbstractArray, a::Int) = app.model(s)[a] -(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, ::Val{:Q}) = - app.model(s, Val(:Q)) -(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, ::Val{:V}) = - app.model(s, Val(:V)) -(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, a::Int) = - app.model(s, Val(:Q))[a] - - -RLBase.batch_estimate(app::NeuralNetworkApproximator, states::AbstractArray) = - app.model(states) - -RLBase.batch_estimate( - app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR}, - states::AbstractArray, - ::Val{:Q}, -) = app.model(states, Val(:Q)) - -RLBase.batch_estimate( - app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR}, - states::AbstractArray, - ::Val{:V}, -) = app.model(states, Val(:V)) +device(app::NeuralNetworkApproximator) = device(app.model) RLBase.update!(app::NeuralNetworkApproximator, gs) = Flux.Optimise.update!(app.optimizer, app.params, gs) diff --git a/src/components/approximators/tabular_approximator.jl b/src/components/approximators/tabular_approximator.jl index cb5a74a..86be23e 100644 --- a/src/components/approximators/tabular_approximator.jl +++ b/src/components/approximators/tabular_approximator.jl @@ -3,7 +3,8 @@ export TabularApproximator """ TabularApproximator(table<:AbstractArray) -For `table` of 1-d, it will create a 
[`V_APPROXIMATOR`](@ref). For `table` of 2-d, it will create a [`QApproximator`]. +For `table` of 1-d, it will serve as a state value approximator. +For `table` of 2-d, it will serve as a state-action value approximator. !!! warning For `table` of 2-d, the first dimension is action and the second dimension is state. @@ -47,5 +48,5 @@ function RLBase.update!(Q::TabularApproximator{2}, correction::Pair{Int,Vector{F end end -RLBase.ApproximatorStyle(::TabularApproximator{1}) = VApproximator() -RLBase.ApproximatorStyle(::TabularApproximator{2}) = QApproximator() +ApproximatorStyle(::TabularApproximator{1}) = V_APPROXIMATOR +ApproximatorStyle(::TabularApproximator{2}) = Q_APPROXIMATOR \ No newline at end of file diff --git a/src/components/components.jl b/src/components/components.jl index 1c457f2..835a9c8 100644 --- a/src/components/components.jl +++ b/src/components/components.jl @@ -1,7 +1,6 @@ -include("learners/learners.jl") -include("policies/policies.jl") +include("trajectories/trajectories.jl") include("approximators/approximators.jl") include("explorers/explorers.jl") -include("trajectories/trajectories.jl") -include("preprocessors.jl") +include("learners/learners.jl") +include("policies/policies.jl") include("agents/agents.jl") diff --git a/src/components/explorers/UCB_explorer.jl b/src/components/explorers/UCB_explorer.jl index 05bc3b6..e4d5f13 100644 --- a/src/components/explorers/UCB_explorer.jl +++ b/src/components/explorers/UCB_explorer.jl @@ -1,32 +1,41 @@ export UCBExplorer +using Random + +mutable struct UCBExplorer{R<:AbstractRNG} <: AbstractExplorer + c::Float64 + actioncounts::Vector{Float64} + step::Int + rng::R +end + """ - UCBExplorer(na; c=2.0, ϵ=1e-10) + UCBExplorer(na; c=2.0, ϵ=1e-10, step=1, seed=nothing) + # Arguments - `na` is the number of actions used to create a internal counter. - `t` is used to store current time step. - `c` is used to control the degree of exploration. +- `seed`, set the seed of inner RNG. """ -mutable struct UCBExplorer <: AbstractExplorer - c::Float64 - actioncounts::Vector{Float64} - step::Int - UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1) = new(c, fill(ϵ, na), 1) -end +UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, seed=nothing) = UCBExplorer(c, fill(ϵ, na), 1, MersenneTwister(seed)) @doc raw""" (ucb::UCBExplorer)(values::AbstractArray) Unlike [`EpsilonGreedyExplorer`](@ref), uncertaintyies are considered in UCB. + !!! note If multiple values with the same maximum value are found. Then a random one will be returned! + ```math A_t = \underset{a}{\arg \max} \left[ Q_t(a) + c \sqrt{\frac{\ln t}{N_t(a)}} \right] ``` + See more details at Section (2.7) on Page 35 of the book *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.* """ function (p::UCBExplorer)(values::AbstractArray) - action = - find_all_max(@. values + p.c * sqrt(log(p.step + 1) / p.actioncounts))[2] |> sample + v, inds = find_all_max(@. values + p.c * sqrt(log(p.step + 1) / p.actioncounts)) + action = sample(p.rng, inds) p.actioncounts[action] += 1 p.step += 1 action diff --git a/src/components/explorers/abstract_explorer.jl b/src/components/explorers/abstract_explorer.jl new file mode 100644 index 0000000..8ae7f24 --- /dev/null +++ b/src/components/explorers/abstract_explorer.jl @@ -0,0 +1,26 @@ +export AbstractExplorer + +""" + (p::AbstractExplorer)(x) + (p::AbstractExplorer)(x, mask) + +Define how to select an action based on action values. 
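As a purely illustrative instance of this interface, the sketch below defines a hypothetical `UniformExplorer` that ignores the estimated values and samples uniformly from the legal actions. It assumes `AbstractExplorer` and the `RLBase` namespace above are in scope, and borrows `Categorical` from Distributions the same way the other explorers in this patch do:

```julia
using Random
using Distributions: Categorical

# Hypothetical explorer: pick uniformly among (legal) actions, ignoring the values.
struct UniformExplorer{R<:AbstractRNG} <: AbstractExplorer
    rng::R
end
UniformExplorer(; seed = nothing) = UniformExplorer(MersenneTwister(seed))

# Select an action given the estimated action values.
(p::UniformExplorer)(x) = rand(p.rng, 1:length(x))
# With a mask, only the legal actions are candidates.
(p::UniformExplorer)(x, mask) = rand(p.rng, findall(mask))

# The corresponding action distributions.
RLBase.get_prob(p::UniformExplorer, x) = Categorical(fill(1 / length(x), length(x)))
RLBase.get_prob(p::UniformExplorer, x, mask) = Categorical(mask ./ sum(mask))
```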
+""" +abstract type AbstractExplorer end + +function (p::AbstractExplorer)(x) end +function (p::AbstractExplorer)(x, mask) end + +""" + get_prob(p::AbstractExplorer, x) -> AbstractDistribution + +Get the action distribution given action values. +""" +function RLBase.get_prob(p::AbstractExplorer, x) end + +""" + get_prob(p::AbstractExplorer, x, mask) + +Similart to `get_prob(p::AbstractExplorer, x)`, but here only the `mask`ed elements are considered. +""" +function RLBase.get_prob(p::AbstractExplorer, x, mask) end \ No newline at end of file diff --git a/src/components/explorers/batch_exploer.jl b/src/components/explorers/batch_exploer.jl index cec2eb1..3e3defb 100644 --- a/src/components/explorers/batch_exploer.jl +++ b/src/components/explorers/batch_exploer.jl @@ -2,7 +2,6 @@ export BatchExplorer """ BatchExplorer(explorer::AbstractExplorer) - BatchExplorer(explorers::AbstractExplorer...) """ struct BatchExplorer{E} <: AbstractExplorer explorer::E @@ -10,6 +9,12 @@ end BatchExplorer(explorers::AbstractExplorer...) = BatchExplorer(explorers) +""" + (x::BatchExplorer)(values::AbstractMatrix) + +Apply inner explorer to each column of `values`. +""" (x::BatchExplorer)(values::AbstractMatrix) = [x.explorer(v) for v in eachcol(values)] + (x::BatchExplorer{<:Tuple})(values::AbstractMatrix) = [explorer(v) for (explorer, v) in zip(x.explorer, eachcol(values))] diff --git a/src/components/explorers/epsilon_greedy_explorer.jl b/src/components/explorers/epsilon_greedy_explorer.jl index 2c22100..9e052ab 100644 --- a/src/components/explorers/epsilon_greedy_explorer.jl +++ b/src/components/explorers/epsilon_greedy_explorer.jl @@ -1,7 +1,7 @@ export EpsilonGreedyExplorer, GreedyExplorer using Random -using Distributions: DiscreteNonParametric +using Distributions:Categorical """ EpsilonGreedyExplorer{T}(;kwargs...) @@ -21,6 +21,8 @@ Two kinds of epsilon-decreasing strategy are implmented here (`linear` and `exp` - `warmup_steps::Int=0`: the number of steps to use `ϵ_init`. - `decay_steps::Int=0`: the number of steps for epsilon to decay from `ϵ_init` to `ϵ_stable`. - `ϵ_stable::Float64`: the epsilon after `warmup_steps + decay_steps`. +- `is_break_tie=false`: randomly select an action of the same maximum values if set to `true`. +- `seed=nothing`: set the seed of internal RNG. # Example @@ -45,17 +47,6 @@ mutable struct EpsilonGreedyExplorer{Kind,IsBreakTie,R} <: AbstractExplorer rng::R end -function Base.copy(p::EpsilonGreedyExplorer{Kind,IsBreakTie,R}) where {Kind,IsBreakTie,R} - EpsilonGreedyExplorer{Kind,IsBreakTie,R}( - p.ϵ_stable, - p.ϵ_init, - p.warmup_steps, - p.decay_steps, - p.step, - copy(p.rng), - ) -end - function EpsilonGreedyExplorer(; ϵ_stable, kind = :linear, @@ -121,7 +112,7 @@ end function (s::EpsilonGreedyExplorer{<:Any,false})(values) ϵ = get_ϵ(s) s.step += 1 - rand(s.rng) >= ϵ ? find_max(values)[2] : rand(s.rng, 1:length(values)) + rand(s.rng) >= ϵ ? findmax(values)[2] : rand(s.rng, 1:length(values)) end function (s::EpsilonGreedyExplorer{<:Any,true})(values, mask) @@ -134,14 +125,14 @@ end function (s::EpsilonGreedyExplorer{<:Any,false})(values, mask) ϵ = get_ϵ(s) s.step += 1 - rand(s.rng) >= ϵ ? find_max(values, mask)[2] : rand(s.rng, findall(mask)) + rand(s.rng) >= ϵ ? 
findmax(values, mask)[2] : rand(s.rng, findall(mask)) end Random.seed!(s::EpsilonGreedyExplorer, seed) = Random.seed!(s.rng, seed) """ - get_prob(s::EpsilonGreedyExplorer, values) -> DiscreteNonParametric - get_prob(s::EpsilonGreedyExplorer, values, mask) -> DiscreteNonParametric + get_prob(s::EpsilonGreedyExplorer, values) ->Categorical + get_prob(s::EpsilonGreedyExplorer, values, mask) ->Categorical Return the probability of selecting each action given the estimated `values` of each action. """ @@ -152,14 +143,33 @@ function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,true}, values) for ind in max_val_inds probs[ind] += (1 - ϵ) / length(max_val_inds) end - DiscreteNonParametric(1:length(probs), probs) + Categorical(probs) +end + +function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,true}, values, action::Integer) + ϵ, n = get_ϵ(s), length(values) + max_val_inds = find_all_max(values)[2] + if action in max_val_inds + ϵ / n + (1 - ϵ) / length(max_val_inds) + else + ϵ / n + end end function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,false}, values) ϵ, n = get_ϵ(s), length(values) probs = fill(ϵ / n, n) - probs[find_max(values)[2]] += 1 - ϵ - DiscreteNonParametric(1:length(probs), probs) + probs[findmax(values)[2]] += 1 - ϵ + Categorical(probs) +end + +function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,false}, values, action::Integer) + ϵ, n = get_ϵ(s), length(values) + if action == findmax(values)[2] + ϵ / n + 1 - ϵ + else + ϵ / n + end end function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,true}, values, mask) @@ -170,32 +180,36 @@ function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,true}, values, mask) for ind in max_val_inds probs[ind] += (1 - ϵ) / length(max_val_inds) end - DiscreteNonParametric(1:length(probs), probs) + Categorical(probs) end function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,false}, values, mask) ϵ, n = get_ϵ(s), length(values) probs = zeros(n) probs[mask] .= ϵ / sum(mask) - probs[find_max(values, mask)[2]] += 1 - ϵ - DiscreteNonParametric(1:length(probs), probs) + probs[findmax(values, mask)[2]] += 1 - ϵ + Categorical(probs) end RLBase.reset!(s::EpsilonGreedyExplorer) = s.step = 1 +# Though we can achieve the same goal by setting the ϵ of [`EpsilonGreedyExplorer`](@ref) to 0, +# the GreedyExplorer is much faster. struct GreedyExplorer <: AbstractExplorer end -(s::GreedyExplorer)(values) = find_max(values)[2] -(s::GreedyExplorer)(values, mask) = find_max(values, mask)[2] +(s::GreedyExplorer)(values) = findmax(values)[2] +(s::GreedyExplorer)(values, mask) = findmax(values, mask)[2] function RLBase.get_prob(s::GreedyExplorer, values) prob = zeros(length(values)) - prob[find_max(values)[2]] = 1.0 - DiscreteNonParametric(1:length(prob), prob) + prob[findmax(values)[2]] = 1.0 + Categorical(prob) end +RLBase.get_prob(s::GreedyExplorer, values, action::Integer) = findmax(values)[2] == action ? 
1.0 : 0.0 + function RLBase.get_prob(s::GreedyExplorer, values, mask) prob = zeros(length(values)) - prob[find_max(values, mask)[2]] = 1.0 - DiscreteNonParametric(1:length(prob), prob) + prob[findmax(values, mask)[2]] = 1.0 + Categorical(prob) end diff --git a/src/components/explorers/explorers.jl b/src/components/explorers/explorers.jl index b276b22..ec4ee42 100644 --- a/src/components/explorers/explorers.jl +++ b/src/components/explorers/explorers.jl @@ -1,3 +1,4 @@ +include("abstract_explorer.jl") include("batch_exploer.jl") include("epsilon_greedy_explorer.jl") include("UCB_explorer.jl") diff --git a/src/components/learners/double_learner.jl b/src/components/learners/double_learner.jl deleted file mode 100644 index edfc1fd..0000000 --- a/src/components/learners/double_learner.jl +++ /dev/null @@ -1,30 +0,0 @@ -export DoubleLearner - -using Random - -""" - DoubleLearner(;L1, L2, rng=MersenneTwister()) - -This is a meta-learner, it will randomly select one learner and update another learner. -The estimation of an observation is the sum of result from two learners. -""" -Base.@kwdef struct DoubleLearner{T1<:AbstractLearner,T2<:AbstractLearner,R<:AbstractRNG} <: - AbstractLearner - L1::T1 - L2::T2 - rng::R = MersenneTwister() -end - -""" - DoubleLearner(l1, l2; seed = nothing) -""" -DoubleLearner(l1, l2; seed = nothing) = DoubleLearner(l1, l2, MersenneTwister(seed)) - -(learner::DoubleLearner)(obs) = learner.L1(obs) .+ learner.L2(obs) - -RLBase.extract_experience(t::AbstractTrajectory, learner::DoubleLearner) = - extract_experience(t, learner.L1) - -update!(learner::DoubleLearner, experience) = - rand(learner.rng, Bool) ? update!(learner.L1, experience) : - update!(learner.L2, experience) diff --git a/src/components/learners/learners.jl b/src/components/learners/learners.jl index 8929f39..ceae6a8 100644 --- a/src/components/learners/learners.jl +++ b/src/components/learners/learners.jl @@ -1 +1,27 @@ -include("double_learner.jl") +export AbstractLearner, extract_experience + +""" + (learner::AbstractLearner)(obs) + +A learner is usually used to estimate state values, state-action values or distributional values based on experiences. +""" +abstract type AbstractLearner end + +function (learner::AbstractLearner)(obs) end + +""" + update!(learner::AbstractLearner, experience) + +Typical `experience` is [`AbstractTrajectory`](@ref). +""" +function RLBase.update!(learner::AbstractLearner, t::AbstractTrajectory) + experience = extract_experience(t, learner) + isnothing(experience) || update!(learner, experience) +end + +function extract_experience end + +""" + get_priority(p::AbstractLearner, experience) +""" +function RLBase.get_priority(p::AbstractLearner, experience) end \ No newline at end of file diff --git a/src/components/policies/Q_based_policy.jl b/src/components/policies/Q_based_policy.jl index b7c78e1..a972204 100644 --- a/src/components/policies/Q_based_policy.jl +++ b/src/components/policies/Q_based_policy.jl @@ -1,36 +1,27 @@ export QBasedPolicy +using MacroTools: @forward + """ QBasedPolicy(;learner::Q, explorer::S) -Use a Q-`learner` to generate the estimations of actions and use `explorer` to get the action. +Use a Q-`learner` to generate estimations of action values. +Then an `explorer` is applied on the estimations to select an action. 
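A hedged construction example (not part of the patch): `RandomQLearner` below is a made-up learner used only to show how the pieces fit together; real learners are expected to come from downstream packages, and the NamedTuple observation follows the convention used in the tests:

```julia
# Made-up learner: maps an observation to (random) action-value estimates.
struct RandomQLearner <: AbstractLearner
    n_actions::Int
end
(learner::RandomQLearner)(obs) = rand(learner.n_actions)

policy = QBasedPolicy(
    learner = RandomQLearner(4),
    explorer = EpsilonGreedyExplorer(; ϵ_stable = 0.1),
)

obs = (reward = 0.0, terminal = false, state = 1)
action = policy(obs)          # learner estimates values, explorer picks the action
dist = get_prob(policy, obs)  # Categorical distribution over the 4 actions
```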
""" Base.@kwdef struct QBasedPolicy{Q<:AbstractLearner,E<:AbstractExplorer} <: AbstractPolicy learner::Q explorer::E end +(π::QBasedPolicy)(obs) = π(obs, ActionStyle(obs)) (π::QBasedPolicy)(obs, ::MinimalActionSet) = obs |> π.learner |> π.explorer (π::QBasedPolicy)(obs, ::FullActionSet) = π.explorer(π.learner(obs), get_legal_actions_mask(obs)) -function RLBase.update!(p::QBasedPolicy, t::AbstractTrajectory) - experience = extract_experience(t, p) - isnothing(experience) || update!(p.learner, experience) -end - -RLBase.update!( - p::QBasedPolicy, - m::AbstractEnvironmentModel, - t::AbstractTrajectory, - n::Int, -) = update!(p.learner, m, t, n) - -RLBase.extract_experience(trajectory::AbstractTrajectory, p::QBasedPolicy) = - extract_experience(trajectory, p.learner) +RLBase.get_prob(p::QBasedPolicy, obs) = get_prob(p, obs, ActionStyle(obs)) RLBase.get_prob(p::QBasedPolicy, obs, ::MinimalActionSet) = get_prob(p.explorer, p.learner(obs)) RLBase.get_prob(p::QBasedPolicy, obs, ::FullActionSet) = get_prob(p.explorer, p.learner(obs), get_legal_actions_mask(obs)) -RLBase.get_priority(p::QBasedPolicy, experience) = get_priority(p.learner, experience) +@forward QBasedPolicy.learner RLBase.get_priority, RLBase.update! diff --git a/src/components/policies/V_based_policy.jl b/src/components/policies/V_based_policy.jl index 48b0df6..1ee23ad 100644 --- a/src/components/policies/V_based_policy.jl +++ b/src/components/policies/V_based_policy.jl @@ -1,7 +1,9 @@ export VBasedPolicy +using MacroTools: @forward + """ - VBasedPolicy(;kwargs...) + VBasedPolicy(;learner, mapping, explorer=GreedyExplorer()) # Key words & Fields @@ -15,6 +17,8 @@ Base.@kwdef struct VBasedPolicy{L<:AbstractLearner,M,E<:AbstractExplorer} <: Abs explorer::E = GreedyExplorer() end +(p::VBasedPolicy)(obs) = p(obs, ActionStyle(obs)) + (p::VBasedPolicy)(obs, ::MinimalActionSet) = p.mapping(obs, p.learner) |> p.explorer function (p::VBasedPolicy)(obs, ::FullActionSet) @@ -22,21 +26,19 @@ function (p::VBasedPolicy)(obs, ::FullActionSet) p.explorer(action_values, get_legal_actions_mask(obs)) end -function RLBase.get_prob(p::VBasedPolicy, obs, ::MinimalActionSet) - get_prob(p.explorer, p.mapping(obs, p.learner)) -end +RLBase.get_prob(p::VBasedPolicy, obs, action::Integer) = get_prob(p, obs, ActionStyle(obs), action) + +RLBase.get_prob(p::VBasedPolicy, obs, ::MinimalActionSet) = get_prob(p.explorer, p.mapping(obs, p.learner)) +RLBase.get_prob(p::VBasedPolicy, obs, ::MinimalActionSet, action) = get_prob(p.explorer, p.mapping(obs, p.learner), action) function RLBase.get_prob(p::VBasedPolicy, obs, ::FullActionSet) action_values = p.mapping(obs, p.learner) get_prob(p.explorer, action_values, get_legal_actions_mask(obs)) end -RLBase.update!(p::VBasedPolicy, experience) = update!(p.learner, experience) - -function RLBase.update!(p::VBasedPolicy, t::AbstractTrajectory) - experience = extract_experience(t, p) - isnothing(experience) || update!(p, experience) +function RLBase.get_prob(p::VBasedPolicy, obs, ::FullActionSet, action) + action_values = p.mapping(obs, p.learner) + get_prob(p.explorer, action_values, get_legal_actions_mask(obs), action) end -RLBase.extract_experience(trajectory::AbstractTrajectory, p::VBasedPolicy) = - extract_experience(trajectory, p.learner) +@forward VBasedPolicy.learner RLBase.get_priority, RLBase.update! 
\ No newline at end of file diff --git a/src/components/policies/off_policy.jl b/src/components/policies/off_policy.jl index b118de3..23053b8 100644 --- a/src/components/policies/off_policy.jl +++ b/src/components/policies/off_policy.jl @@ -1,5 +1,7 @@ export OffPolicy +using MacroTools: @forward + """ OffPolicy(π_target::P, π_behavior::B) -> OffPolicy{P,B} """ @@ -10,20 +12,4 @@ end (π::OffPolicy)(obs) = π.π_behavior(obs) -function RLBase.update!(π::OffPolicy, t::AbstractTrajectory) - experience = extract_experience(t, π) - isnothing(experience) || update!(π, experience) -end - -function RLBase.update!(π::OffPolicy{<:VBasedPolicy}, transitions::NamedTuple) - # ??? define a `get_batch_prob` function for efficiency - weights = [ - get_prob(π.π_target, (state = s,), a) / get_prob(π.π_behavior, (state = s,), a) - for (s, a) in zip(transitions.states, transitions.actions) - ] # TODO: implement iterate interface for (SubArray of) CircularArrayBuffer - experience = merge(transitions, (weights = weights,)) - update!(π.π_target, experience) -end - -RLBase.extract_experience(t::AbstractTrajectory, π::OffPolicy) = - extract_experience(t, π.π_target) +@forward OffPolicy.π_behavior RLBase.get_priority, RLBase.get_prob \ No newline at end of file diff --git a/src/components/policies/policies.jl b/src/components/policies/policies.jl index c3284dc..ab599f9 100644 --- a/src/components/policies/policies.jl +++ b/src/components/policies/policies.jl @@ -1,5 +1,3 @@ -include("random_policy.jl") -include("weighted_ramdom_policy.jl") include("V_based_policy.jl") include("Q_based_policy.jl") include("off_policy.jl") diff --git a/src/components/policies/random_policy.jl b/src/components/policies/random_policy.jl deleted file mode 100644 index 3645856..0000000 --- a/src/components/policies/random_policy.jl +++ /dev/null @@ -1,43 +0,0 @@ -export RandomPolicy - -using Random - -""" - RandomPolicy(action_space, rng) - -Randomly return a valid action. -""" -struct RandomPolicy{S<:AbstractSpace,R<:AbstractRNG} <: AbstractPolicy - action_space::S - rng::R -end - -Base.show(io::IO, p::RandomPolicy) = print(io, "RandomPolicy($(p.action_space))") - -Random.seed!(p::RandomPolicy, seed) = Random.seed!(p.rng, seed) - -""" - RandomPolicy(action_space; seed=nothing) -""" -RandomPolicy(s; seed = nothing) = RandomPolicy(s, MersenneTwister(seed)) - -""" - RandomPolicy(env::AbstractEnv; seed=nothing) -""" -RandomPolicy(env::AbstractEnv; seed = nothing) = - RandomPolicy(get_action_space(env), MersenneTwister(seed)) - -function (p::RandomPolicy)(obs, ::FullActionSet) - legal_actions = get_legal_actions(obs) - length(legal_actions) == 0 ? 
get_invalid_action(obs) : rand(p.rng, legal_actions) -end - -(p::RandomPolicy)(obs, ::MinimalActionSet) = rand(p.rng, p.action_space) -(p::RandomPolicy)(obs::BatchObs, ::MinimalActionSet) = - [rand(p.rng, p.action_space) for _ in 1:length(obs)] - -RLBase.update!(p::RandomPolicy, experience) = nothing - -RLBase.get_prob(p::RandomPolicy, s) = - fill(1 / length(p.action_space), length(p.action_space)) -RLBase.get_prob(p::RandomPolicy, s, a) = 1 / length(p.action_space) diff --git a/src/components/policies/weighted_ramdom_policy.jl b/src/components/policies/weighted_ramdom_policy.jl deleted file mode 100644 index 08cbc7f..0000000 --- a/src/components/policies/weighted_ramdom_policy.jl +++ /dev/null @@ -1,69 +0,0 @@ -export WeightedRandomPolicy - -using Random -using StatsBase: sample, Weights - -""" - WeightedRandomPolicy(actions, weight, sums, rng) - -Similar to [`RandomPolicy`](@ref), but the probability of -each action is set in advance instead of a uniform distribution. - -- `actions` are all possible actions. -- `weight` can be an 1-D or 2-D array. If it's 1-D, then the `weight` applies to all states. If it's 2-D, then the state is assume to be of `Int` and for different state, the corresponding weight is selected. -""" -struct WeightedRandomPolicy{N,A,W<:AbstractArray,S,R<:AbstractRNG} <: AbstractPolicy - actions::A - weight::W - sums::S - rng::R -end - -function WeightedRandomPolicy( - weight::W; - actions = axes(weights, 1), - seed = nothing, -) where {W<:AbstractArray} - rng = MersenneTwister(seed) - N = ndims(W) - - if N == 1 - sums = sum(weight) - elseif N == 2 - sums = vec(sum(weight, dims = 1)) - end - WeightedRandomPolicy{ndims(W),typeof(actions),W,typeof(sums),typeof(rng)}( - actions, - weight, - sums, - rng, - ) -end - -Random.seed!(p::WeightedRandomPolicy, seed) = Random.seed!(p.rng, seed) - -RLBase.update!(p::WeightedRandomPolicy, experience) = nothing - -(p::WeightedRandomPolicy{1})(obs, ::MinimalActionSet) = - sample(p.rng, p.actions, Weights(p.weight, p.sums)) - -function (p::WeightedRandomPolicy{1})(obs, ::FullActionSet) - legal_actions = get_legal_actions(obs) - legal_actions_mask = get_legal_actions_mask(obs) - masked_weight = @view p.weight[legal_actions_mask] - legal_actions[sample(p.rng, Weights(masked_weight))] -end - -function (p::WeightedRandomPolicy{2})(obs, ::MinimalActionSet) - s = get_state(obs) - weight = @view p.weight[:, s] - sample(p.rng, p.actions, Weights(weight, p.sums[s])) -end - -function (p::WeightedRandomPolicy{2})(obs, ::FullActionSet) - s = get_state(obs) - legal_actions = get_legal_actions(obs) - legal_actions_mask = get_legal_actions_mask(obs) - masked_weight = @view p.weight[legal_actions_mask, s] - legal_actions[sample(p.rng, Weights(masked_weight))] -end diff --git a/src/components/preprocessors.jl b/src/components/preprocessors.jl deleted file mode 100644 index 56ca552..0000000 --- a/src/components/preprocessors.jl +++ /dev/null @@ -1,28 +0,0 @@ -export CloneStatePreprocessor, ComposedPreprocessor - -(p::AbstractPreprocessor)(obs) = StateOverriddenObs(obs = obs, state = p(get_state(obs))) - -""" - ComposedPreprocessor(p::AbstractPreprocessor...) - -Compose multiple preprocessors. -""" -struct ComposedPreprocessor{T} <: AbstractPreprocessor - preprocessors::T -end - -ComposedPreprocessor(p::AbstractPreprocessor...) 
= ComposedPreprocessor(p) -(p::ComposedPreprocessor)(obs) = reduce((x, f) -> f(x), p.preprocessors, init = obs) - -##### -# CloneStatePreprocessor -##### - -""" - CloneStatePreprocessor() - -Do `deepcopy` for the state in an observation. -""" -struct CloneStatePreprocessor <: AbstractPreprocessor end - -(p::CloneStatePreprocessor)(obs) = StateOverriddenObs(obs, deepcopy(get_state(obs))) diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl new file mode 100644 index 0000000..eb80680 --- /dev/null +++ b/src/components/trajectories/abstract_trajectory.jl @@ -0,0 +1,91 @@ +export AbstractTrajectory, get_trace, RTSA, SARTSA + +""" + AbstractTrajectory{names,types} <: AbstractArray{NamedTuple{names,types},1} + +A trajectory is used to record some useful information +during the interactions between agents and environments. + +# Parameters +- `names`::`NTuple{Symbol}`, indicate what fields to be recorded. +- `types`::`Tuple{DataType...}`, the datatypes of `names`. + +The length of `names` and `types` must match. + +Required Methods: + +- [`get_trace`](@ref) +- `Base.push!(t::AbstractTrajectory, kv::Pair{Symbol})` +- `Base.pop!(t::AbstractTrajectory, s::Symbol)` + +Optional Methods: + +- `Base.length` +- `Base.size` +- `Base.lastindex` +- `Base.isempty` +- `Base.empty!` +""" +abstract type AbstractTrajectory{names,types} <: AbstractArray{NamedTuple{names,types},1} end + +# some typical trace names +"An alias of `(:reward, :terminal, :state, :action)`" +const RTSA = (:reward, :terminal, :state, :action) + +"An alias of `(:state, :action, :reward, :terminal, :next_state, :next_action)`" +const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) + +""" + get_trace(t::AbstractTrajectory, s::NTuple{N,Symbol}) where {N} +""" +get_trace(t::AbstractTrajectory, s::NTuple{N,Symbol}) where {N} = + NamedTuple{s}(get_trace(t, x) for x in s) + +""" + get_trace(t::AbstractTrajectory, s::Symbol...) +""" +get_trace(t::AbstractTrajectory, s::Symbol...) = get_trace(t, s) + +""" + get_trace(t::AbstractTrajectory{names}) where {names} +""" +get_trace(t::AbstractTrajectory{names}) where {names} = + NamedTuple{names}(get_trace(t, x) for x in names) + +Base.length(t::AbstractTrajectory) = maximum(length(x) for x in get_trace(t)) +Base.size(t::AbstractTrajectory) = (length(t),) +Base.lastindex(t::AbstractTrajectory) = length(t) +Base.getindex(t::AbstractTrajectory{names,types}, i::Int) where {names,types} = NamedTuple{names,types}(Tuple(x[i] for x in get_trace(t))) + +Base.isempty(t::AbstractTrajectory) = all(isempty(t) for t in get_trace(t)) + +function Base.empty!(t::AbstractTrajectory) + for x in get_trace(t) + empty!(x) + end +end + +""" + Base.push!(t::AbstractTrajectory; kwargs...) +""" +function Base.push!(t::AbstractTrajectory; kwargs...) + for kv in kwargs + push!(t, kv) + end +end + +""" + Base.pop!(t::AbstractTrajectory{names}) where {names} +`pop!` out one element of each trace in `t` +""" +function Base.pop!(t::AbstractTrajectory{names}) where {names} + pop!(t, names...) +end + +""" + Base.pop!(t::AbstractTrajectory, s::Symbol...) +`pop!` out one element of the traces specified in `s` +""" +function Base.pop!(t::AbstractTrajectory, s::Symbol...) 
+ NamedTuple{s}(pop!(t, x) for x in s) +end \ No newline at end of file diff --git a/src/components/trajectories/circular_compact_PSARTSA_buffer.jl b/src/components/trajectories/circular_compact_PSARTSA_buffer.jl index 60dd63f..861a09a 100644 --- a/src/components/trajectories/circular_compact_PSARTSA_buffer.jl +++ b/src/components/trajectories/circular_compact_PSARTSA_buffer.jl @@ -44,7 +44,7 @@ end @forward CircularCompactPSARTSATrajectory.trajectory Base.length, Base.isempty -RLBase.get_trace(t::CircularCompactPSARTSATrajectory, s::Symbol) = +get_trace(t::CircularCompactPSARTSATrajectory, s::Symbol) = s == :priority ? t.priority : get_trace(t.trajectory, s) function Base.getindex(b::CircularCompactPSARTSATrajectory, i::Int) diff --git a/src/components/trajectories/common.jl b/src/components/trajectories/common.jl index 787c42e..babd0aa 100644 --- a/src/components/trajectories/common.jl +++ b/src/components/trajectories/common.jl @@ -1,7 +1,7 @@ const CompactSARTSATrajectory = Union{CircularCompactSARTSATrajectory,VectorialCompactSARTSATrajectory} -function RLBase.get_trace(b::CompactSARTSATrajectory, s::Symbol) +function get_trace(b::CompactSARTSATrajectory, s::Symbol) if s == :state || s == :action select_last_dim(b[s], 1:(nframes(b[s]) > 1 ? nframes(b[s]) - 1 : nframes(b[s]))) elseif s == :reward || s == :terminal diff --git a/src/components/trajectories/episodic_compact_SARTSA_buffer.jl b/src/components/trajectories/episodic_compact_SARTSA_buffer.jl index 169aea0..6170aa3 100644 --- a/src/components/trajectories/episodic_compact_SARTSA_buffer.jl +++ b/src/components/trajectories/episodic_compact_SARTSA_buffer.jl @@ -25,7 +25,7 @@ Base.push!, Base.pop! # avoid method ambiguous -RLBase.get_trace(t::EpisodicCompactSARTSATrajectory, s::Symbol) = +get_trace(t::EpisodicCompactSARTSATrajectory, s::Symbol) = get_trace(t.trajectories, s) Base.getindex(t::EpisodicCompactSARTSATrajectory, i::Int) = getindex(t.trajectories, i) Base.pop!(t::EpisodicCompactSARTSATrajectory, s::Symbol...) = pop!(t.trajectories, s...) diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl index 8386e75..050a8f5 100644 --- a/src/components/trajectories/trajectories.jl +++ b/src/components/trajectories/trajectories.jl @@ -1,3 +1,4 @@ +include("abstract_trajectory.jl") include("trajectory.jl") include("vectorial_trajectory.jl") include("circular_trajectory.jl") diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 5dcfb9e..0282835 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -13,7 +13,7 @@ end "A helper function to access inner fields" Base.getindex(t::Trajectory, s::Symbol) = getproperty(t.trajectories, s) -RLBase.get_trace(t::Trajectory, s::Symbol) = t[s] +get_trace(t::Trajectory, s::Symbol) = t[s] function Base.push!(t::Trajectory, kv::Pair{Symbol}) k, v = kv diff --git a/src/core/run.jl b/src/core/run.jl index ff2f26f..f11edb0 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -1,6 +1,6 @@ import Base: run -run(agent, env, args...) = run(DynamicStyle(env), agent, env, args...) +run(agent, env::AbstractEnv, args...) = run(DynamicStyle(env), agent, env, args...) 
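With the tightened signature, the generic entry point only applies to real environments and simply forwards to the `Sequential`/`Simultaneous` methods based on `DynamicStyle(env)`. A hedged usage sketch follows; `StopAfterEpisode` and `TotalRewardPerEpisode` are assumed to be the stop condition and hook defined in `src/core` (not shown in this diff), and `agent`/`env` stand for any `AbstractAgent`/`AbstractEnv`:

```julia
# Given some `agent::AbstractAgent` and `env::AbstractEnv`:
hook = TotalRewardPerEpisode()                  # assumed hook from src/core
run(agent, env, StopAfterEpisode(100), hook)    # dispatches via DynamicStyle(env)
hook.rewards                                    # per-episode rewards recorded by the hook
```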
function run( ::Sequential, diff --git a/src/utils/base.jl b/src/utils/base.jl index 60cdbdd..81cd829 100644 --- a/src/utils/base.jl +++ b/src/utils/base.jl @@ -4,7 +4,6 @@ export nframes, select_last_frame, consecutive_view, find_all_max, - find_max, huber_loss, huber_loss_unreduced, discount_rewards, @@ -90,88 +89,22 @@ consecutive_view(cb::AbstractArray, inds::Vector{Int}, n_stack::Int, n_horizeon: ), ) -""" - find_all_max(A::AbstractArray) - -Like `find_max`, but all the indices of the maximum value are returned. - -!!! warning - All elements of value `NaN` in `A` will be ignored, unless all elements are `NaN`. - In that case, the returned maximum value will be `NaN` and the returned indices will be `collect(1:length(A))` - -# Examples - -```julia-repl -julia> find_all_max([-Inf, -Inf, -Inf]) -(-Inf, [1, 2, 3]) -julia> find_all_max([Inf, Inf, Inf]) -(Inf, [1, 2, 3]) -julia> find_all_max([Inf, 0, Inf]) -(Inf, [1, 3]) -julia> find_all_max([0,1,2,1,2,1,0]) -(2, [3, 5]) -``` -""" -function find_all_max(A) - maxval = typemin(eltype(A)) - idxs = Int[] - for (i, x) in enumerate(A) - if !isnan(x) - if x > maxval - maxval = x - empty!(idxs) - push!(idxs, i) - elseif x == maxval - push!(idxs, i) - end - end - end - if length(idxs) == 0 - NaN, collect(1:length(A)) - else - maxval, idxs - end +function find_all_max(x) + v = maximum(x) + v, findall(==(v), x) end -""" - find_all_max(A, mask) - -Similar to `find_all_max(A)`, but only the masked elements in `A` will be considered. -""" -function find_all_max(A, mask) - maxval = typemin(eltype(A)) - idxs = Int[] - for (i, x) in enumerate(A) - if mask[i] && (!isnan(x)) - if x > maxval - maxval = x - empty!(idxs) - push!(idxs, i) - elseif x == maxval - push!(idxs, i) - end - end - end - if length(idxs) == 0 - NaN, collect(1:length(A)) - else - maxval, idxs - end +function find_all_max(x, mask::AbstractVector{Bool}) + v = maximum(view(x,mask)) + v, [k for (m, k) in zip(mask, keys(x)) if m && x[k]==v] end -find_max(A) = findmax(A) +# !!! watch https://github.com/JuliaLang/julia/pull/35316#issuecomment-622629895 +Base.findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain) +_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m) -function find_max(A, mask) - maxval = typemin(eltype(A)) - ind = 0 - for (i, x) in enumerate(A) - if mask[i] && x >= maxval - maxval = x - ind = i - end - end - maxval, ind -end +# !!! 
type piracy +Base.findmax(A::AbstractVector, mask::AbstractVector{Bool}) = findmax(i -> A[i], view(keys(A), mask)) function logitcrossentropy_unreduced(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) return vec(-sum(y .* logsoftmax(logŷ), dims = 1)) diff --git a/test/components/approximators.jl b/test/components/approximators.jl index bb121d1..02c9709 100644 --- a/test/components/approximators.jl +++ b/test/components/approximators.jl @@ -16,12 +16,6 @@ q_values = NN(rand(2)) @test size(q_values) == (3,) - v = NN(rand(2), 2) - @test v isa Number - - batch_q_values = batch_estimate(NN, rand(2, 5)) - @test size(batch_q_values) == (3, 5) - gs = gradient(params(NN)) do sum(NN(rand(2, 5))) end diff --git a/test/components/components.jl b/test/components/components.jl index 312770f..6763cce 100644 --- a/test/components/components.jl +++ b/test/components/components.jl @@ -1,6 +1,4 @@ include("approximators.jl") -include("preprocessors.jl") include("explorers.jl") -include("policies.jl") include("trajectories.jl") include("agents.jl") diff --git a/test/components/environments.jl b/test/components/environments.jl deleted file mode 100644 index 3fab824..0000000 --- a/test/components/environments.jl +++ /dev/null @@ -1,26 +0,0 @@ -@testset "environments" begin - - @testset "test API" begin - - for env in [CartPoleEnv(), WrappedEnv(CloneStatePreprocessor(), CartPoleEnv())] - reset!(env) - action_space = get_action_space(env) - observation_space = get_observation_space(env) - - obs = observe(env) - - for _ in 1:1000 - if get_terminal(obs) - reset!(env) - obs = observe(env) - end - @test get_state(obs) ∈ observation_space - action = rand(action_space) - env(action) - obs = observe(env) - end - end - - end - -end diff --git a/test/components/explorers.jl b/test/components/explorers.jl index 9e587f1..411120e 100644 --- a/test/components/explorers.jl +++ b/test/components/explorers.jl @@ -20,14 +20,6 @@ target_prob; atol = 0.005, )) - - explorer_copy = copy(explorer) - reset!(explorer_copy) - Random.seed!(explorer_copy, 123) - - new_actions = [explorer_copy(values) for _ in 1:10000] - - @test actions == new_actions end @testset "linear" begin @@ -52,8 +44,14 @@ explorer(xs) end - reset!(explorer) - + explorer = EpsilonGreedyExplorer(; + ϵ_stable = 0.1, + ϵ_init = 0.9, + warmup_steps = 3, + decay_steps = 8, + kind = :linear, + is_break_tie = true, + ) for ϵ in E @test RLCore.get_ϵ(explorer) ≈ ϵ @test isapprox( @@ -63,7 +61,14 @@ explorer(xs) end - reset!(explorer) + explorer = EpsilonGreedyExplorer(; + ϵ_stable = 0.1, + ϵ_init = 0.9, + warmup_steps = 3, + decay_steps = 8, + kind = :linear, + is_break_tie = true, + ) for i in 1:100 @test mask[explorer(xs, mask)] end @@ -98,11 +103,6 @@ [ϵ / 5, ϵ / 5, ϵ / 5 + (1 - ϵ) / 2, ϵ / 5, ϵ / 5 + (1 - ϵ) / 2]; atol = 1e-5, ) - - reset!(explorer) - for i in 1:100 - @test mask[explorer(xs, mask)] - end end end diff --git a/test/components/policies.jl b/test/components/policies.jl deleted file mode 100644 index b1f5d97..0000000 --- a/test/components/policies.jl +++ /dev/null @@ -1,89 +0,0 @@ -@testset "policies" begin - - @testset "RandomPolicy" begin - p = RandomPolicy(DiscreteSpace(3)) - obs = (reward = 0.0, terminal = false, state = 1) - - Random.seed!(p, 321) - actions = [p(obs) for _ in 1:100] - Random.seed!(p, 321) - new_actions = [p(obs) for _ in 1:100] - @test actions == new_actions - end - - @testset "WeightedRandomPolicy" begin - @testset "1D" begin - weight = [1, 2, 3] - ratio = weight ./ sum(weight) - N = 1000 - actions = [:a, :b, :c] - p = 
WeightedRandomPolicy(weight, actions = actions, seed = 123) - - samples = [p(nothing, MINIMAL_ACTION_SET) for _ in 1:N] - stats = countmap(samples) - for (a, r) in zip(actions, ratio) - @test isapprox(r, stats[a] / N, atol = 0.05) - end - - legal_actions = [:aa, :cc] - legal_actions_mask = [true, false, true] - obs = ( - reward = 0.0, - terminal = false, - state = nothing, - legal_actions = legal_actions, - legal_actions_mask = legal_actions_mask, - ) - samples = [p(obs) for _ in 1:N] - stats = countmap(samples) - - weighted_ratio = ratio[legal_actions_mask] ./ sum(ratio[legal_actions_mask]) - for i in 1:length(legal_actions) - @test isapprox(stats[legal_actions[i]] / N, weighted_ratio[i], atol = 0.05) - end - end - - @testset "2D" begin - n_state = 2 - weight = reshape(1:6, 3, n_state) - ratio = weight ./ sum(weight, dims = 1) - N = 1000 - actions = [:a, :b, :c] - - p = WeightedRandomPolicy(weight, actions = actions, seed = 123) - - for state in 1:n_state - samples = [p((state = state,), MINIMAL_ACTION_SET) for _ in 1:N] - stats = countmap(samples) - for (a, r) in zip(actions, ratio[:, state]) - @test isapprox(r, stats[a] / N, atol = 0.05) - end - end - - for state in 1:n_state - legal_actions = [:aa, :cc] - legal_actions_mask = [true, false, true] - obs = ( - reward = 0.0, - terminal = false, - state = state, - legal_actions = legal_actions, - legal_actions_mask = legal_actions_mask, - ) - samples = [p(obs) for _ in 1:N] - stats = countmap(samples) - - weighted_ratio = - ratio[legal_actions_mask, state] ./ - sum(ratio[legal_actions_mask, state]) - for i in 1:length(legal_actions) - @test isapprox( - stats[legal_actions[i]] / N, - weighted_ratio[i], - atol = 0.05, - ) - end - end - end - end -end diff --git a/test/components/preprocessors.jl b/test/components/preprocessors.jl deleted file mode 100644 index 2ddef5f..0000000 --- a/test/components/preprocessors.jl +++ /dev/null @@ -1,8 +0,0 @@ -@testset "preprocessors" begin - obs1 = (state = [1, 2, 3],) - p = CloneStatePreprocessor() - obs2 = p(obs1) - - @test get_state(obs1) !== get_state(obs2) - @test get_state(obs1) == get_state(obs2) -end
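Finally, a short illustration of the simplified `find_all_max` helpers and the `findmax(values, mask)` method added in `src/utils/base.jl` above (expected results shown as comments; this is a usage sketch, not additional test code):

```julia
values = [0, 1, 2, 1, 2]
mask   = Bool[1, 1, 0, 1, 0]

find_all_max(values)        # (2, [3, 5])  all indices attaining the maximum
find_all_max(values, mask)  # (1, [2, 4])  only masked entries are considered
findmax(values, mask)       # (1, 2)       maximum value and one masked index
```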