
Commit acf0427

Simplify APIs (#47)
* clean up policies
* clean up learners
* clean up explorers
* move some preprocessors into RLBase
* clean up agents
* add approximators
* update test cases
* minor fix
* sync
* bugfix
* remove unused struct
* update RLBase compat version
1 parent 3fa6d53 commit acf0427

39 files changed: +398 -538 lines

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ Flux = "0.10"
 GPUArrays = "2, 3"
 MacroTools = "0.5"
 ProgressMeter = "1.2"
-ReinforcementLearningBase = "0.6"
+ReinforcementLearningBase = "0.7"
 StatsBase = "0.32"
 Zygote = "0.4"
 julia = "1.3"

src/ReinforcementLearningCore.jl

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ export RLCore

 include("extensions/extensions.jl")
 include("utils/utils.jl")
-include("core/core.jl")
 include("components/components.jl")
+include("core/core.jl")

 end # module
src/components/agents/abstract_agent.jl

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+export AbstractAgent,
+    get_role,
+    PreExperimentStage,
+    PostExperimentStage,
+    PreEpisodeStage,
+    PostEpisodeStage,
+    PreActStage,
+    PostActStage,
+    PRE_EXPERIMENT_STAGE,
+    POST_EXPERIMENT_STAGE,
+    PRE_EPISODE_STAGE,
+    POST_EPISODE_STAGE,
+    PRE_ACT_STAGE,
+    POST_ACT_STAGE
+
+"""
+    (agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -> action
+    (agent::AbstractAgent)(stage::AbstractStage, obs)
+
+Similar to [`AbstractPolicy`](@ref), an agent is also a functional object which takes in an observation and returns an action.
+The main difference is that we divide an experiment into the following stages:
+
+- `PRE_EXPERIMENT_STAGE`
+- `PRE_EPISODE_STAGE`
+- `PRE_ACT_STAGE`
+- `POST_ACT_STAGE`
+- `POST_EPISODE_STAGE`
+- `POST_EXPERIMENT_STAGE`
+
+In each stage, different types of agents may have different behaviors, such as updating the experience buffer, the environment model, or the policy.
+"""
+abstract type AbstractAgent end
+
+function get_role(::AbstractAgent) end
+
+"""
+                     +-----------------------------------------------------------+
+                     |Episode                                                     |
+                     |                                                            |
+PRE_EXPERIMENT_STAGE |      PRE_ACT_STAGE            POST_ACT_STAGE              |  POST_EXPERIMENT_STAGE
+      |              |           |                         |                     |          |
+      v              |  +-----+  v       +-------+         v       +-----+       |          v
+      -------------------+ env +-------->+ agent +---------------->+ env +---> ... ------->......
+      |              ^   +-----+   obs   +-------+      action     +-----+   ^   |
+      |              |                                                       |   |
+      |              +--PRE_EPISODE_STAGE            POST_EPISODE_STAGE------+   |
+      |                                                                          |
+      |                                                                          |
+      +-----------------------------------------------------------+
+"""
+abstract type AbstractStage end
+
+struct PreExperimentStage <: AbstractStage end
+struct PostExperimentStage <: AbstractStage end
+struct PreEpisodeStage <: AbstractStage end
+struct PostEpisodeStage <: AbstractStage end
+struct PreActStage <: AbstractStage end
+struct PostActStage <: AbstractStage end
+
+const PRE_EXPERIMENT_STAGE = PreExperimentStage()
+const POST_EXPERIMENT_STAGE = PostExperimentStage()
+const PRE_EPISODE_STAGE = PreEpisodeStage()
+const POST_EPISODE_STAGE = PostEpisodeStage()
+const PRE_ACT_STAGE = PreActStage()
+const POST_ACT_STAGE = PostActStage()
+
+(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs)
+function (agent::AbstractAgent)(stage::AbstractStage, obs) end
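To make the stage-based dispatch above concrete, here is a minimal sketch (not part of this commit) of a custom agent that counts how often each action is chosen; `CountingAgent` is a hypothetical name and `policy` is assumed to be any callable returning an action:

```julia
using ReinforcementLearningCore

# Hypothetical agent: wraps a callable policy and tracks action counts per episode.
struct CountingAgent{P} <: AbstractAgent
    policy::P
    counts::Dict{Any,Int}
end

CountingAgent(policy) = CountingAgent(policy, Dict{Any,Int}())

# Reset the bookkeeping at the start of every episode.
function (agent::CountingAgent)(::PreEpisodeStage, obs)
    empty!(agent.counts)
    nothing
end

# `agent(obs)` falls back to `agent(PRE_ACT_STAGE, obs)`, so this method picks the action.
function (agent::CountingAgent)(::PreActStage, obs)
    action = agent.policy(obs)
    agent.counts[action] = get(agent.counts, action, 0) + 1
    action
end
```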

src/components/agents/agent.jl

Lines changed: 3 additions & 3 deletions
@@ -11,15 +11,15 @@ Generally speaking, it does nothing but update the trajectory and policy appropr

 - `policy`::[`AbstractPolicy`](@ref): the policy to use
 - `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment
-- `role=DEFAULT_PLAYER`: used to distinguish different agents
+- `role=:DEFAULT_PLAYER`: used to distinguish different agents
 """
 Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent
     policy::P
     trajectory::T
-    role::R = DEFAULT_PLAYER
+    role::R = :DEFAULT_PLAYER
 end

-RLBase.get_role(agent::Agent) = agent.role
+get_role(agent::Agent) = agent.role

 #####
 # EpisodicCompactSARTSATrajectory
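A hedged usage sketch of the new Symbol-based `role` default; `RandomPolicy` and `EpisodicCompactSARTSATrajectory` are assumed constructors from this package and their exact arguments may differ:

```julia
agent = Agent(
    policy = RandomPolicy(1:4),                      # assumed: a policy over 4 discrete actions
    trajectory = EpisodicCompactSARTSATrajectory(),  # assumed default constructor
    role = :DEFAULT_PLAYER,                          # a plain Symbol, replacing the removed DEFAULT_PLAYER constant
)

get_role(agent)  # => :DEFAULT_PLAYER
```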

src/components/agents/agents.jl

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
+include("abstract_agent.jl")
 include("agent.jl")
 include("dyna_agent.jl")

src/components/agents/dyna_agent.jl

Lines changed: 2 additions & 2 deletions
@@ -27,11 +27,11 @@ Base.@kwdef struct DynaAgent{
     policy::P
     model::M
     trajectory::B
-    role::R = DEFAULT_PLAYER
+    role::R = :DEFAULT_PLAYER
     plan_step::Int = 10
 end

-RLBase.get_role(agent::DynaAgent) = agent.role
+get_role(agent::DynaAgent) = agent.role

 function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PreEpisodeStage,
src/components/approximators/abstract_approximator.jl

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+export AbstractApproximator,
+    ApproximatorStyle,
+    Q_APPROXIMATOR,
+    QApproximator,
+    V_APPROXIMATOR,
+    VApproximator
+
+"""
+    (app::AbstractApproximator)(obs)
+
+An approximator is a functional object for value estimation.
+It serves as a black box that provides an abstraction over different
+kinds of approximation methods (for example a DNN provided by Flux or Knet).
+"""
+abstract type AbstractApproximator end
+
+"""
+    update!(a::AbstractApproximator, correction)
+
+Usually the `correction` is the gradient of the inner parameters.
+"""
+function RLBase.update!(a::AbstractApproximator, correction) end
+
+#####
+# traits
+#####
+
+abstract type AbstractApproximatorStyle end
+
+"""
+Used to detect what an [`AbstractApproximator`](@ref) is approximating.
+"""
+function ApproximatorStyle(::AbstractApproximator) end
+
+struct QApproximator <: AbstractApproximatorStyle end
+
+const Q_APPROXIMATOR = QApproximator()
+
+struct VApproximator <: AbstractApproximatorStyle end
+
+const V_APPROXIMATOR = VApproximator()
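As an illustration of the new interface, here is a minimal sketch (not part of this commit) of a concrete approximator; `LinearVApproximator` is a hypothetical name:

```julia
# Hypothetical linear state-value approximator built on the interface above.
struct LinearVApproximator <: AbstractApproximator
    weights::Vector{Float64}
end

# Value estimate: dot product of weights and a feature vector.
(app::LinearVApproximator)(obs) = sum(app.weights .* obs)

# Apply a gradient-like correction to the inner parameters in place.
RLBase.update!(app::LinearVApproximator, correction) = (app.weights .-= correction; app)

# Declare what this approximator estimates.
ApproximatorStyle(::LinearVApproximator) = V_APPROXIMATOR
```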
src/components/approximators/approximators.jl

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
+include("abstract_approximator.jl")
 include("tabular_approximator.jl")
 include("neural_network_approximator.jl")

src/components/approximators/neural_network_approximator.jl

Lines changed: 6 additions & 39 deletions
@@ -2,12 +2,6 @@ export NeuralNetworkApproximator

 using Flux

-struct NeuralNetworkApproximator{T,M,O,P} <: AbstractApproximator
-    model::M
-    optimizer::O
-    params::P
-end
-
 """
     NeuralNetworkApproximator(;kwargs)

@@ -18,45 +12,18 @@ Use a DNN model for value estimation.
 - `model`, a Flux based DNN model.
 - `optimizer`
 - `parameters=params(model)`
-- `kind=Q_APPROXIMATOR`, specify the type of model.
 """
-function NeuralNetworkApproximator(;
-    model::M,
-    optimizer::O,
-    parameters::P = params(model),
-    kind = Q_APPROXIMATOR,
-) where {M,O,P}
-    NeuralNetworkApproximator{kind,M,O,P}(model, optimizer, parameters)
+Base.@kwdef struct NeuralNetworkApproximator{M,O,P} <: AbstractApproximator
+    model::M
+    optimizer::O
+    params::P = params(model)
 end

-device(app::NeuralNetworkApproximator) = device(app.model)
+(app::NeuralNetworkApproximator)(x) = app.model(x)

 Flux.params(app::NeuralNetworkApproximator) = app.params

-(app::NeuralNetworkApproximator)(s::AbstractArray) = app.model(s)
-(app::NeuralNetworkApproximator{Q_APPROXIMATOR})(s::AbstractArray, a::Int) = app.model(s)[a]
-(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, ::Val{:Q}) =
-    app.model(s, Val(:Q))
-(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, ::Val{:V}) =
-    app.model(s, Val(:V))
-(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, a::Int) =
-    app.model(s, Val(:Q))[a]
-
-
-RLBase.batch_estimate(app::NeuralNetworkApproximator, states::AbstractArray) =
-    app.model(states)
-
-RLBase.batch_estimate(
-    app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR},
-    states::AbstractArray,
-    ::Val{:Q},
-) = app.model(states, Val(:Q))
-
-RLBase.batch_estimate(
-    app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR},
-    states::AbstractArray,
-    ::Val{:V},
-) = app.model(states, Val(:V))
+device(app::NeuralNetworkApproximator) = device(app.model)

 RLBase.update!(app::NeuralNetworkApproximator, gs) =
     Flux.Optimise.update!(app.optimizer, app.params, gs)
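A hedged usage sketch of the simplified keyword constructor; the model and optimizer below are illustrative Flux choices, not prescribed by the commit:

```julia
using Flux

app = NeuralNetworkApproximator(
    model = Chain(Dense(4, 32, relu), Dense(32, 2)),
    optimizer = Descent(0.01),
    # `params` defaults to params(model)
)

q = app(rand(Float32, 4))   # forward pass, equivalent to app.model(x)

gs = gradient(() -> sum(app(rand(Float32, 4))), Flux.params(app))
RLBase.update!(app, gs)     # delegates to Flux.Optimise.update!(optimizer, params, gs)
```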

src/components/approximators/tabular_approximator.jl

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,8 @@ export TabularApproximator
 """
     TabularApproximator(table<:AbstractArray)

-For `table` of 1-d, it will create a [`V_APPROXIMATOR`](@ref). For `table` of 2-d, it will create a [`QApproximator`].
+For `table` of 1-d, it will serve as a state value approximator.
+For `table` of 2-d, it will serve as a state-action value approximator.

 !!! warning
     For `table` of 2-d, the first dimension is action and the second dimension is state.
@@ -47,5 +48,5 @@ function RLBase.update!(Q::TabularApproximator{2}, correction::Pair{Int,Vector{F
     end
 end

-RLBase.ApproximatorStyle(::TabularApproximator{1}) = VApproximator()
-RLBase.ApproximatorStyle(::TabularApproximator{2}) = QApproximator()
+ApproximatorStyle(::TabularApproximator{1}) = V_APPROXIMATOR
+ApproximatorStyle(::TabularApproximator{2}) = Q_APPROXIMATOR
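A hedged sketch of the trait change above: `ApproximatorStyle` now returns the singleton constants instead of constructing new style objects:

```julia
V = TabularApproximator(zeros(10))      # 1-d table -> state values
Q = TabularApproximator(zeros(3, 10))   # 2-d table -> (action, state) values

ApproximatorStyle(V) === V_APPROXIMATOR  # true
ApproximatorStyle(Q) === Q_APPROXIMATOR  # true
```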

src/components/components.jl

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
-include("learners/learners.jl")
-include("policies/policies.jl")
+include("trajectories/trajectories.jl")
 include("approximators/approximators.jl")
 include("explorers/explorers.jl")
-include("trajectories/trajectories.jl")
-include("preprocessors.jl")
+include("learners/learners.jl")
+include("policies/policies.jl")
 include("agents/agents.jl")

src/components/explorers/UCB_explorer.jl

Lines changed: 18 additions & 9 deletions
@@ -1,32 +1,41 @@
 export UCBExplorer

+using Random
+
+mutable struct UCBExplorer{R<:AbstractRNG} <: AbstractExplorer
+    c::Float64
+    actioncounts::Vector{Float64}
+    step::Int
+    rng::R
+end
+
 """
-    UCBExplorer(na; c=2.0, ϵ=1e-10)
+    UCBExplorer(na; c=2.0, ϵ=1e-10, step=1, seed=nothing)
+
 # Arguments
 - `na` is the number of actions used to create an internal counter.
 - `t` is used to store the current time step.
 - `c` is used to control the degree of exploration.
+- `seed`, set the seed of the inner RNG.
 """
-mutable struct UCBExplorer <: AbstractExplorer
-    c::Float64
-    actioncounts::Vector{Float64}
-    step::Int
-    UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1) = new(c, fill(ϵ, na), 1)
-end
+UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, seed = nothing) = UCBExplorer(c, fill(ϵ, na), 1, MersenneTwister(seed))

 @doc raw"""
     (ucb::UCBExplorer)(values::AbstractArray)
 Unlike [`EpsilonGreedyExplorer`](@ref), uncertainties are considered in UCB.
+
 !!! note
     If multiple values with the same maximum value are found,
     then a random one will be returned!
+
 ```math
 A_t = \underset{a}{\arg \max} \left[ Q_t(a) + c \sqrt{\frac{\ln t}{N_t(a)}} \right]
 ```
+
 See more details at Section (2.7) on Page 35 of the book *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.*
 """ function (p::UCBExplorer)(values::AbstractArray)
-    action =
-        find_all_max(@. values + p.c * sqrt(log(p.step + 1) / p.actioncounts))[2] |> sample
+    v, inds = find_all_max(@. values + p.c * sqrt(log(p.step + 1) / p.actioncounts))
+    action = sample(p.rng, inds)
     p.actioncounts[action] += 1
     p.step += 1
     action
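A hedged usage sketch of the reworked `UCBExplorer`, which now carries its own RNG for reproducible tie-breaking:

```julia
explorer = UCBExplorer(4; seed = 123)  # 4 actions, inner MersenneTwister seeded with 123

values = [0.1, 0.5, 0.5, 0.2]
a = explorer(values)  # UCB-maximizing action; ties broken with the inner RNG
```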
src/components/explorers/abstract_explorer.jl

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+export AbstractExplorer
+
+"""
+    (p::AbstractExplorer)(x)
+    (p::AbstractExplorer)(x, mask)
+
+Define how to select an action based on action values.
+"""
+abstract type AbstractExplorer end
+
+function (p::AbstractExplorer)(x) end
+function (p::AbstractExplorer)(x, mask) end
+
+"""
+    get_prob(p::AbstractExplorer, x) -> AbstractDistribution
+
+Get the action distribution given action values.
+"""
+function RLBase.get_prob(p::AbstractExplorer, x) end
+
+"""
+    get_prob(p::AbstractExplorer, x, mask)
+
+Similar to `get_prob(p::AbstractExplorer, x)`, but here only the `mask`ed elements are considered.
+"""
+function RLBase.get_prob(p::AbstractExplorer, x, mask) end
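A minimal sketch (not from this commit) of a concrete explorer implementing the interface above; `GreedyExplorer` is a hypothetical name:

```julia
# Always pick the greedy action, optionally restricted to a Boolean mask of legal actions.
struct GreedyExplorer <: AbstractExplorer end

(p::GreedyExplorer)(x) = argmax(x)
(p::GreedyExplorer)(x, mask) = findall(mask)[argmax(view(x, mask))]
```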

src/components/explorers/batch_exploer.jl

Lines changed: 6 additions & 1 deletion
@@ -2,14 +2,19 @@ export BatchExplorer

 """
     BatchExplorer(explorer::AbstractExplorer)
-    BatchExplorer(explorers::AbstractExplorer...)
 """
 struct BatchExplorer{E} <: AbstractExplorer
     explorer::E
 end

 BatchExplorer(explorers::AbstractExplorer...) = BatchExplorer(explorers)

+"""
+    (x::BatchExplorer)(values::AbstractMatrix)
+
+Apply inner explorer to each column of `values`.
+"""
 (x::BatchExplorer)(values::AbstractMatrix) = [x.explorer(v) for v in eachcol(values)]
+
 (x::BatchExplorer{<:Tuple})(values::AbstractMatrix) =
     [explorer(v) for (explorer, v) in zip(x.explorer, eachcol(values))]
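A hedged usage sketch: a single inner explorer applied column-wise to a batch of action values, one column per environment:

```julia
explorer = BatchExplorer(UCBExplorer(3; seed = 0))

batch_values = rand(3, 5)          # 3 actions x 5 environments
actions = explorer(batch_values)   # a Vector of 5 actions, one per column
```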
