This repository was archived by the owner on May 6, 2021. It is now read-only.

Simplify APIs #47

Merged 13 commits on May 7, 2020
2 changes: 1 addition & 1 deletion Project.toml
@@ -28,7 +28,7 @@ Flux = "0.10"
GPUArrays = "2, 3"
MacroTools = "0.5"
ProgressMeter = "1.2"
ReinforcementLearningBase = "0.6"
ReinforcementLearningBase = "0.7"
StatsBase = "0.32"
Zygote = "0.4"
julia = "1.3"
2 changes: 1 addition & 1 deletion src/ReinforcementLearningCore.jl
@@ -13,7 +13,7 @@ export RLCore

include("extensions/extensions.jl")
include("utils/utils.jl")
include("core/core.jl")
include("components/components.jl")
include("core/core.jl")

end # module
68 changes: 68 additions & 0 deletions src/components/agents/abstract_agent.jl
@@ -0,0 +1,68 @@
export AbstractAgent,
get_role,
PreExperimentStage,
PostExperimentStage,
PreEpisodeStage,
PostEpisodeStage,
PreActStage,
PostActStage,
PRE_EXPERIMENT_STAGE,
POST_EXPERIMENT_STAGE,
PRE_EPISODE_STAGE,
POST_EPISODE_STAGE,
PRE_ACT_STAGE,
POST_ACT_STAGE

"""
(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -> action
(agent::AbstractAgent)(stage::AbstractStage, obs)

Similar to [`AbstractPolicy`](@ref), an agent is also a functional object that takes in an observation and returns an action.
The main difference is that we divide an experiment into the following stages:

- `PRE_EXPERIMENT_STAGE`
- `PRE_EPISODE_STAGE`
- `PRE_ACT_STAGE`
- `POST_ACT_STAGE`
- `POST_EPISODE_STAGE`
- `POST_EXPERIMENT_STAGE`

In each stage, different types of agents may behave differently, for example by updating the experience buffer, the environment model, or the policy.
"""
abstract type AbstractAgent end

function get_role(::AbstractAgent) end

"""
+-----------------------------------------------------------+
|Episode |
| |
PRE_EXPERIMENT_STAGE | PRE_ACT_STAGE POST_ACT_STAGE | POST_EXPERIMENT_STAGE
| | | | | |
v | +-----+ v +-------+ v +-----+ | v
--------------------->+ env +------>+ agent +------->+ env +---> ... ------->......
| ^ +-----+ obs +-------+ action +-----+ ^ |
| | | |
| +--PRE_EPISODE_STAGE POST_EPISODE_STAGE----+ |
| |
| |
+-----------------------------------------------------------+
"""
abstract type AbstractStage end

struct PreExperimentStage <: AbstractStage end
struct PostExperimentStage <: AbstractStage end
struct PreEpisodeStage <: AbstractStage end
struct PostEpisodeStage <: AbstractStage end
struct PreActStage <: AbstractStage end
struct PostActStage <: AbstractStage end

const PRE_EXPERIMENT_STAGE = PreExperimentStage()
const POST_EXPERIMENT_STAGE = PostExperimentStage()
const PRE_EPISODE_STAGE = PreEpisodeStage()
const POST_EPISODE_STAGE = PostEpisodeStage()
const PRE_ACT_STAGE = PreActStage()
const POST_ACT_STAGE = PostActStage()

(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs)
function (agent::AbstractAgent)(stage::AbstractStage, obs) end
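To make the stage-based interface concrete, here is a minimal sketch (not part of this PR) of a toy agent that simply counts its actions; only the hooks it cares about are implemented, everything else falls back to the empty default methods above. It assumes the exports of this file are in scope, e.g. via `using ReinforcementLearningCore`.

```julia
using ReinforcementLearningCore

# Toy agent: counts how many times it has been asked to act and always picks action 1.
mutable struct CountingAgent <: AbstractAgent
    n_actions::Int
end

# Hook called before each action (also reached via the one-argument form `agent(obs)`).
function (agent::CountingAgent)(::PreActStage, obs)
    agent.n_actions += 1
    return 1
end

# Hook called after each episode; here it just resets the counter.
(agent::CountingAgent)(::PostEpisodeStage, obs) = (agent.n_actions = 0)

agent = CountingAgent(0)
agent(:dummy_obs)                      # -> 1, equivalent to agent(PRE_ACT_STAGE, :dummy_obs)
agent(POST_EPISODE_STAGE, :dummy_obs)  # -> 0, counter reset
```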
6 changes: 3 additions & 3 deletions src/components/agents/agent.jl
@@ -11,15 +11,15 @@ Generally speaking, it does nothing but update the trajectory and policy appropriately.

- `policy`::[`AbstractPolicy`](@ref): the policy to use
- `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment
- `role=DEFAULT_PLAYER`: used to distinguish different agents
- `role=:DEFAULT_PLAYER`: used to distinguish different agents
"""
Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent
policy::P
trajectory::T
role::R = DEFAULT_PLAYER
role::R = :DEFAULT_PLAYER
end

RLBase.get_role(agent::Agent) = agent.role
get_role(agent::Agent) = agent.role

#####
# EpisodicCompactSARTSATrajectory
1 change: 1 addition & 0 deletions src/components/agents/agents.jl
@@ -1,2 +1,3 @@
include("abstract_agent.jl")
include("agent.jl")
include("dyna_agent.jl")
4 changes: 2 additions & 2 deletions src/components/agents/dyna_agent.jl
@@ -27,11 +27,11 @@ Base.@kwdef struct DynaAgent{
policy::P
model::M
trajectory::B
role::R = DEFAULT_PLAYER
role::R = :DEFAULT_PLAYER
plan_step::Int = 10
end

RLBase.get_role(agent::DynaAgent) = agent.role
get_role(agent::DynaAgent) = agent.role

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::PreEpisodeStage,
41 changes: 41 additions & 0 deletions src/components/approximators/abstract_approximator.jl
@@ -0,0 +1,41 @@
export AbstractApproximator,
ApproximatorStyle,
Q_APPROXIMATOR,
QApproximator,
V_APPROXIMATOR,
VApproximator

"""
(app::AbstractApproximator)(obs)

An approximator is a functional object for value estimation.
It serves as a black box that provides an abstraction over different
kinds of approximation methods (for example, a DNN provided by Flux or Knet).
"""
abstract type AbstractApproximator end

"""
update!(a::AbstractApproximator, correction)

Usually the `correction` is the gradient with respect to the inner parameters.
"""
function RLBase.update!(a::AbstractApproximator, correction) end

#####
# traits
#####

abstract type AbstractApproximatorStyle end

"""
Used to detect what an [`AbstractApproximator`](@ref) is approximating.
"""
function ApproximatorStyle(::AbstractApproximator) end

struct QApproximator <: AbstractApproximatorStyle end

const Q_APPROXIMATOR = QApproximator()

struct VApproximator <: AbstractApproximatorStyle end

const V_APPROXIMATOR = VApproximator()
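The trait above lets callers query what an approximator estimates without knowing its concrete type. Below is a hypothetical illustration (not part of this PR) using a made-up `LinearVApproximator`; a method added from outside the package is written against `RLCore.ApproximatorStyle`.

```julia
using LinearAlgebra: dot
using ReinforcementLearningCore

# Hypothetical linear state-value approximator declaring its style via the trait.
struct LinearVApproximator{W} <: AbstractApproximator
    weights::W
end

(app::LinearVApproximator)(obs) = dot(app.weights, obs)          # value estimate
RLCore.ApproximatorStyle(::LinearVApproximator) = V_APPROXIMATOR

# Callers can now branch on the trait instead of on concrete types:
is_state_value(app::AbstractApproximator) = ApproximatorStyle(app) === V_APPROXIMATOR
is_state_value(LinearVApproximator([0.5, 0.5]))                  # true
```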
1 change: 1 addition & 0 deletions src/components/approximators/approximators.jl
@@ -1,2 +1,3 @@
include("abstract_approximator.jl")
include("tabular_approximator.jl")
include("neural_network_approximator.jl")
45 changes: 6 additions & 39 deletions src/components/approximators/neural_network_approximator.jl
@@ -2,12 +2,6 @@ export NeuralNetworkApproximator

using Flux

struct NeuralNetworkApproximator{T,M,O,P} <: AbstractApproximator
model::M
optimizer::O
params::P
end

"""
NeuralNetworkApproximator(;kwargs)

@@ -18,45 +12,18 @@ Use a DNN model for value estimation.
- `model`, a Flux based DNN model.
- `optimizer`
- `parameters=params(model)`
- `kind=Q_APPROXIMATOR`, specify the type of model.
"""
function NeuralNetworkApproximator(;
model::M,
optimizer::O,
parameters::P = params(model),
kind = Q_APPROXIMATOR,
) where {M,O,P}
NeuralNetworkApproximator{kind,M,O,P}(model, optimizer, parameters)
Base.@kwdef struct NeuralNetworkApproximator{M,O,P} <: AbstractApproximator
model::M
optimizer::O
params::P = params(model)
end

device(app::NeuralNetworkApproximator) = device(app.model)
(app::NeuralNetworkApproximator)(x) = app.model(x)

Flux.params(app::NeuralNetworkApproximator) = app.params

(app::NeuralNetworkApproximator)(s::AbstractArray) = app.model(s)
(app::NeuralNetworkApproximator{Q_APPROXIMATOR})(s::AbstractArray, a::Int) = app.model(s)[a]
(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, ::Val{:Q}) =
app.model(s, Val(:Q))
(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, ::Val{:V}) =
app.model(s, Val(:V))
(app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR})(s::AbstractArray, a::Int) =
app.model(s, Val(:Q))[a]


RLBase.batch_estimate(app::NeuralNetworkApproximator, states::AbstractArray) =
app.model(states)

RLBase.batch_estimate(
app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR},
states::AbstractArray,
::Val{:Q},
) = app.model(states, Val(:Q))

RLBase.batch_estimate(
app::NeuralNetworkApproximator{HYBRID_APPROXIMATOR},
states::AbstractArray,
::Val{:V},
) = app.model(states, Val(:V))
device(app::NeuralNetworkApproximator) = device(app.model)

RLBase.update!(app::NeuralNetworkApproximator, gs) =
Flux.Optimise.update!(app.optimizer, app.params, gs)
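To illustrate the simplified keyword constructor, here is a rough sketch (not part of this PR) with a small Flux model; the network size, loss, and target below are arbitrary.

```julia
using Flux
using ReinforcementLearningCore

# Sketch only: `params` defaults to params(model), matching the new kwdef struct.
app = NeuralNetworkApproximator(
    model = Chain(Dense(4, 32, relu), Dense(32, 2)),
    optimizer = ADAM(),
)

x = rand(Float32, 4)
q = app(x)                                   # forward pass, equivalent to app.model(x)

# One illustrative gradient step; update!(app, gs) defined above wraps the same call.
target = Float32[1.0, 0.0]
gs = gradient(() -> Flux.mse(app(x), target), Flux.params(app))
Flux.Optimise.update!(app.optimizer, Flux.params(app), gs)
```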
7 changes: 4 additions & 3 deletions src/components/approximators/tabular_approximator.jl
@@ -3,7 +3,8 @@ export TabularApproximator
"""
TabularApproximator(table<:AbstractArray)

For `table` of 1-d, it will create a [`V_APPROXIMATOR`](@ref). For `table` of 2-d, it will create a [`QApproximator`].
For `table` of 1-d, it will serve as a state value approximator.
For `table` of 2-d, it will serve as a state-action value approximator.

!!! warning
For `table` of 2-d, the first dimension is action and the second dimension is state.
@@ -47,5 +48,5 @@ function RLBase.update!(Q::TabularApproximator{2}, correction::Pair{Int,Vector{F
end
end

RLBase.ApproximatorStyle(::TabularApproximator{1}) = VApproximator()
RLBase.ApproximatorStyle(::TabularApproximator{2}) = QApproximator()
ApproximatorStyle(::TabularApproximator{1}) = V_APPROXIMATOR
ApproximatorStyle(::TabularApproximator{2}) = Q_APPROXIMATOR
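A brief sketch (not part of this PR) of how the two table shapes relate to the trait above, assuming the constructor accepts a raw array as the docstring suggests.

```julia
using ReinforcementLearningCore

V = TabularApproximator(zeros(10))       # 1-d table: one value per state
Q = TabularApproximator(zeros(4, 10))    # 2-d table: action × state, per the warning above

ApproximatorStyle(V) === V_APPROXIMATOR  # true
ApproximatorStyle(Q) === Q_APPROXIMATOR  # true
```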
7 changes: 3 additions & 4 deletions src/components/components.jl
@@ -1,7 +1,6 @@
include("learners/learners.jl")
include("policies/policies.jl")
include("trajectories/trajectories.jl")
include("approximators/approximators.jl")
include("explorers/explorers.jl")
include("trajectories/trajectories.jl")
include("preprocessors.jl")
include("learners/learners.jl")
include("policies/policies.jl")
include("agents/agents.jl")
27 changes: 18 additions & 9 deletions src/components/explorers/UCB_explorer.jl
@@ -1,32 +1,41 @@
export UCBExplorer

using Random

mutable struct UCBExplorer{R<:AbstractRNG} <: AbstractExplorer
c::Float64
actioncounts::Vector{Float64}
step::Int
rng::R
end

"""
UCBExplorer(na; c=2.0, ϵ=1e-10)
UCBExplorer(na; c=2.0, ϵ=1e-10, step=1, seed=nothing)

# Arguments
- `na` is the number of actions, used to create an internal counter.
- `step` is used to store the current time step.
- `c` is used to control the degree of exploration.
- `seed` is used to set the seed of the inner RNG.
"""
mutable struct UCBExplorer <: AbstractExplorer
c::Float64
actioncounts::Vector{Float64}
step::Int
UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1) = new(c, fill(ϵ, na), 1)
end
UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, seed = nothing) = UCBExplorer(c, fill(ϵ, na), step, MersenneTwister(seed))

@doc raw"""
(ucb::UCBExplorer)(values::AbstractArray)
Unlike [`EpsilonGreedyExplorer`](@ref), uncertainties are taken into account in UCB.

!!! note
    If multiple entries share the maximum value, a random one among them is returned.

```math
A_t = \underset{a}{\arg \max} \left[ Q_t(a) + c \sqrt{\frac{\ln t}{N_t(a)}} \right]
```

See Section 2.7 (page 35) of *Sutton, Richard S., and Andrew G. Barto. Reinforcement Learning: An Introduction. MIT Press, 2018* for more details.
""" function (p::UCBExplorer)(values::AbstractArray)
action =
find_all_max(@. values + p.c * sqrt(log(p.step + 1) / p.actioncounts))[2] |> sample
v, inds = find_all_max(@. values + p.c * sqrt(log(p.step + 1) / p.actioncounts))
action = sample(p.rng, inds)
p.actioncounts[action] += 1
p.step += 1
action
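For reference, a small usage sketch (not part of this PR) showing the new seedable constructor; the value estimates are made up.

```julia
using ReinforcementLearningCore

explorer = UCBExplorer(3; seed = 123)   # 3 actions, reproducible inner RNG

values = [0.0, 0.5, 0.2]                # current action-value estimates
a = explorer(values)                    # action with the largest value + exploration bonus

# Each call increments the chosen action's count and the time step, so actions that
# have been tried rarely keep receiving a larger bonus on later calls.
for _ in 1:10
    explorer(values)
end
```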
26 changes: 26 additions & 0 deletions src/components/explorers/abstract_explorer.jl
@@ -0,0 +1,26 @@
export AbstractExplorer

"""
(p::AbstractExplorer)(x)
(p::AbstractExplorer)(x, mask)

Define how to select an action based on action values.
"""
abstract type AbstractExplorer end

function (p::AbstractExplorer)(x) end
function (p::AbstractExplorer)(x, mask) end

"""
get_prob(p::AbstractExplorer, x) -> AbstractDistribution

Get the action distribution given action values.
"""
function RLBase.get_prob(p::AbstractExplorer, x) end

"""
get_prob(p::AbstractExplorer, x, mask)

Similar to `get_prob(p::AbstractExplorer, x)`, but only the `mask`ed elements are considered here.
"""
function RLBase.get_prob(p::AbstractExplorer, x, mask) end
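As a hypothetical illustration of this interface (not part of this PR), here is a greedy explorer that always picks the highest value, with the masked variant restricting the choice to allowed actions.

```julia
using ReinforcementLearningCore

struct GreedyExplorer <: AbstractExplorer end

# Select the action with the highest value.
(p::GreedyExplorer)(x) = argmax(x)

# Same, but only among entries where `mask` is true.
function (p::GreedyExplorer)(x, mask)
    masked = map((v, m) -> m ? v : -Inf, x, mask)
    return argmax(masked)
end

GreedyExplorer()([0.1, 0.9, 0.3])                        # -> 2
GreedyExplorer()([0.1, 0.9, 0.3], [true, false, true])   # -> 3
```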
7 changes: 6 additions & 1 deletion src/components/explorers/batch_exploer.jl
@@ -2,14 +2,19 @@ export BatchExplorer

"""
BatchExplorer(explorer::AbstractExplorer)
BatchExplorer(explorers::AbstractExplorer...)
"""
struct BatchExplorer{E} <: AbstractExplorer
explorer::E
end

BatchExplorer(explorers::AbstractExplorer...) = BatchExplorer(explorers)

"""
(x::BatchExplorer)(values::AbstractMatrix)

Apply inner explorer to each column of `values`.
"""
(x::BatchExplorer)(values::AbstractMatrix) = [x.explorer(v) for v in eachcol(values)]

(x::BatchExplorer{<:Tuple})(values::AbstractMatrix) =
[explorer(v) for (explorer, v) in zip(x.explorer, eachcol(values))]
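A short sketch (not part of this PR) of batched exploration, reusing the hypothetical `GreedyExplorer` from the previous example; each column of the matrix holds the action values of one environment.

```julia
batch = BatchExplorer(GreedyExplorer())

values = [0.1 0.9 0.2;
          0.8 0.1 0.3]
batch(values)                                   # -> [2, 1, 2], one action per column

# The vararg constructor pairs one explorer with each column instead:
batch2 = BatchExplorer(GreedyExplorer(), GreedyExplorer())
batch2(values[:, 1:2])                          # -> [2, 1]
```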