
Commit d660085

added experiments to test
adjusted tests and merged main
1 parent: 5df8e66

File tree

5 files changed (+85, −8 lines)

src/ReinforcementLearningExperiments/Project.toml

Lines changed: 5 additions & 1 deletion
@@ -6,12 +6,14 @@ version = "0.3.1"
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 ReinforcementLearningZoo = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"

@@ -29,7 +31,9 @@ julia = "1.9"

 [extras]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["CUDA", "Test"]
+test = ["CUDA", "PyCall", "Test"]

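PyCall appears only under [extras] and the test target, so it is pulled in just for the test environment, while Requires becomes a regular dependency. A minimal sketch of exercising that target with the standard Pkg workflow (not part of this diff):

    using Pkg
    # The test environment additionally resolves CUDA, PyCall and Test through [targets].
    Pkg.test("ReinforcementLearningExperiments")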
src/ReinforcementLearningExperiments/src/experiments/MARL/DQN_mpe_simple.jl renamed to src/ReinforcementLearningExperiments/deps/experiments/experiments/MARL/DQN_mpe_simple.jl

Lines changed: 2 additions & 2 deletions
@@ -17,14 +17,14 @@ using Flux.Losses: huber_loss
 function RLCore.Experiment(
     ::Val{:JuliaRL},
     ::Val{:DQN},
-    ::Val{:MPESimple};
+    ::Val{:MPESimple},
     seed=123,
     n=1,
     γ=0.99f0,
     is_enable_double_DQN=true
 )
     rng = StableRNG(seed)
-    env = discrete2standard_discrete(PettingzooEnv("mpe.simple_v2"; seed=seed))
+    env = discrete2standard_discrete(PettingZooEnv("mpe.simple_v2"; seed=seed))
     ns, na = length(state(env)), length(action_space(env))

     agent = Agent(

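The renamed MPE experiment drives a PettingZoo environment through PyCall, so it needs a Python installation with the pettingzoo package visible to PyCall. A hedged usage sketch, assuming PyCall is configured and the conditional include added later in this commit has fired:

    using PyCall                              # Python bridge; must be loadable for PettingZooEnv
    using ReinforcementLearningExperiments
    run(E`JuliaRL_DQN_MPESimple`)             # same invocation as in the updated runtests.jl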
src/ReinforcementLearningExperiments/src/experiments/MARL/IDQN_TicTacToe.jl renamed to src/ReinforcementLearningExperiments/deps/experiments/experiments/MARL/IDQN_TicTacToe.jl

Lines changed: 67 additions & 5 deletions
@@ -7,7 +7,6 @@
 # ---

 using StableRNGs
-using ReinforcementLearning
 using ReinforcementLearningBase
 using ReinforcementLearningZoo
 using ReinforcementLearningCore
@@ -16,7 +15,6 @@ using Flux
 using Flux.Losses: huber_loss
 using Flux: glorot_uniform

-using ProgressMeter


 rng = StableRNG(1234)
@@ -25,6 +23,71 @@ cap = 100

 RLCore.forward(L::DQNLearner, state::A) where {A <: Real} = RLCore.forward(L, [state])

+
+episodes_per_step = 25
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:IDQN},
+    ::Val{:TicTacToe},
+    seed=123,
+    n=1,
+    γ=0.99f0,
+    is_enable_double_DQN=true
+)
+    rng = StableRNG(seed)
+    create_policy() = QBasedPolicy(
+        learner=DQNLearner(
+            approximator=Approximator(
+                model=TwinNetwork(
+                    Chain(
+                        Dense(1, 512, relu; init=glorot_uniform(rng)),
+                        Dense(512, 256, relu; init=glorot_uniform(rng)),
+                        Dense(256, 9; init=glorot_uniform(rng)),
+                    );
+                    sync_freq=100
+                ),
+                optimiser=Adam(),
+            ),
+            n=n,
+            γ=γ,
+            is_enable_double_DQN=is_enable_double_DQN,
+            loss_func=huber_loss,
+            rng=rng,
+        ),
+        explorer=EpsilonGreedyExplorer(
+            kind=:exp,
+            ϵ_stable=0.01,
+            decay_steps=500,
+            rng=rng,
+        ),
+    )
+
+    e = TicTacToeEnv();
+    m = MultiAgentPolicy(NamedTuple((player =>
+        Agent(player != :Cross ? create_policy() : RandomPolicy(; rng=rng),
+            Trajectory(
+                container=CircularArraySARTTraces(
+                    capacity=cap,
+                    state=Integer => (1,),
+                ),
+                sampler=NStepBatchSampler{SS′ART}(
+                    n=n,
+                    γ=γ,
+                    batch_size=1,
+                    rng=rng
+                ),
+                controller=InsertSampleRatioController(
+                    threshold=1,
+                    n_inserted=0
+                ))
+        )
+        for player in players(e)))
+    );
+    hooks = MultiAgentHook(NamedTuple((p => TotalRewardPerEpisode() for p ∈ players(e))))
+    Experiment(m, e, StopAfterEpisode(episodes_per_step), hooks)
+end
+
 create_policy() = QBasedPolicy(
     learner=DQNLearner(
         approximator=Approximator(
@@ -36,7 +99,7 @@ create_policy() = QBasedPolicy(
                );
                sync_freq=100
            ),
-            optimiser=ADAM(),
+            optimiser=Adam(),
        ),
        n=32,
        γ=0.99f0,
@@ -75,9 +138,8 @@ m = MultiAgentPolicy(NamedTuple((player =>
 );
 hooks = MultiAgentHook(NamedTuple((p => TotalRewardPerEpisode() for p ∈ players(e))))

-episodes_per_step = 25
 win_rates = (Cross=Float64[], Nought=Float64[])
-@showprogress for i ∈ 1:2
+for i ∈ 1:2
     run(m, e, StopAfterEpisode(episodes_per_step; is_show_progress=false), hooks)
     wr_cross = sum(hooks[:Cross].rewards)/(i*episodes_per_step)
     wr_nought = sum(hooks[:Nought].rewards)/(i*episodes_per_step)

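The new Experiment method registers the IDQN TicTacToe setup with the usual experiment machinery. A sketch of two ways it could be run; the E-string form is what the updated runtests.jl uses, and the direct call relies on the default argument values shown above:

    run(E`JuliaRL_IDQN_TicTacToe`)                                       # experiment string macro
    run(RLCore.Experiment(Val(:JuliaRL), Val(:IDQN), Val(:TicTacToe)))   # direct call; defaults seed=123, n=1, γ=0.99f0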
src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl

Lines changed: 5 additions & 0 deletions
@@ -1,6 +1,7 @@
 module ReinforcementLearningExperiments

 using Reexport
+using Requires

 @reexport using ReinforcementLearningCore, ReinforcementLearningBase, ReinforcementLearningZoo

@@ -19,6 +20,10 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_Rainbow_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_VPG_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_TRPO_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_MPO_CartPole.jl"))
+include(joinpath(EXPERIMENTS_DIR, "IDQN_TicTacToe.jl"))
+@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include(
+    joinpath(EXPERIMENTS_DIR, "DQN_mpe_simple.jl")
+)

 # dynamic loading environments
 function __init__() end

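For reference, Requires.jl documents registering @require hooks inside the module's __init__, so the callback is installed when the package loads. A minimal sketch of that variant, reusing the same UUID and EXPERIMENTS_DIR (an alternative pattern, not what this commit does):

    # Conditional include registered at package initialization time (Requires.jl's documented pattern).
    function __init__()
        @require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include(
            joinpath(EXPERIMENTS_DIR, "DQN_mpe_simple.jl")
        )
    end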
src/ReinforcementLearningExperiments/test/runtests.jl

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,10 @@
 using ReinforcementLearningExperiments
 using CUDA

+using Requires
+
+
+
 CUDA.allowscalar(false)

 run(E`JuliaRL_NFQ_CartPole`)
@@ -15,6 +19,8 @@ run(E`JuliaRL_VPG_CartPole`)
 run(E`JuliaRL_MPODiscrete_CartPole`)
 run(E`JuliaRL_MPOContinuous_CartPole`)
 run(E`JuliaRL_MPOCovariance_CartPole`)
+@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" run(E`JuliaRL_DQN_MPESimple`)
+run(E`JuliaRL_IDQN_TicTacToe`)
 # run(E`JuliaRL_BC_CartPole`)
 # run(E`JuliaRL_VMPO_CartPole`)
 # run(E`JuliaRL_BasicDQN_MountainCar`)

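Since @require is intended for package __init__ code, a plain test script could instead gate the Python-backed experiment on whether PyCall resolves at all. A hedged alternative sketch (an assumption, not part of this commit):

    # Only attempt the PettingZoo experiment when PyCall is available in the environment.
    if Base.find_package("PyCall") !== nothing
        using PyCall
        run(E`JuliaRL_DQN_MPESimple`)
    end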