Commit dcbfbc9

Implementing Rainbow model.
1 parent 56aa154 commit dcbfbc9

18 files changed: +434 -43 lines changed

README.md (+1)

@@ -117,6 +117,7 @@ that are ready to run and easy to modify for other similar usecases:
 - N-Step Bellman updates
 - Distributional Q-Learning
 - Noisy Networks for Exploration
+- Rainbow (combination of the above)


 # Examples

examples-configs/rl/atari/dqn_rainbow_param/asterix_rp_dqn_distributional.yaml (+1 -2)

@@ -39,8 +39,7 @@ reinforcer:
     replay_buffer:
       name: vel.rl.buffers.circular_replay_buffer

-      # buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
-      buffer_initial_size: 200_000 # How many samples we need in the buffer before we start using replay buffer
+      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
       buffer_capacity: 1_000_000

       # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer

examples-configs/rl/atari/dqn_rainbow_param/asterix_rp_dqn_raw.yaml (+1 -2)

@@ -36,8 +36,7 @@ reinforcer:
     replay_buffer:
       name: vel.rl.buffers.circular_replay_buffer

-      # buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
-      buffer_initial_size: 200_000 # How many samples we need in the buffer before we start using replay buffer
+      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
       buffer_capacity: 1_000_000

       # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer

examples-configs/rl/atari/dqn_rainbow_param/asteroids_rp_dqn_noisynet.yaml (+1 -2)

@@ -43,8 +43,7 @@ reinforcer:
     replay_buffer:
       name: vel.rl.buffers.circular_replay_buffer

-      # buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
-      buffer_initial_size: 200_000 # How many samples we need in the buffer before we start using replay buffer
+      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
       buffer_capacity: 1_000_000

       # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer

examples-configs/rl/atari/dqn_rainbow_param/asteroids_rp_dqn_raw.yaml (+1 -2)

@@ -36,8 +36,7 @@ reinforcer:
     replay_buffer:
       name: vel.rl.buffers.circular_replay_buffer

-      # buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
-      buffer_initial_size: 200_000 # How many samples we need in the buffer before we start using replay buffer
+      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
       buffer_capacity: 1_000_000

       # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer

examples-configs/rl/atari/dqn_rainbow_param/atlantis_rp_dqn_nstep.yaml (+1 -2)

@@ -40,8 +40,7 @@ reinforcer:
     replay_buffer:
       name: vel.rl.buffers.circular_replay_buffer

-      # buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
-      buffer_initial_size: 200_000 # How many samples we need in the buffer before we start using replay buffer
+      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
       buffer_capacity: 1_000_000

       # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer

examples-configs/rl/atari/dqn_rainbow_param/atlantis_rp_dqn_raw.yaml (+1 -2)

@@ -36,8 +36,7 @@ reinforcer:
     replay_buffer:
       name: vel.rl.buffers.circular_replay_buffer

-      # buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
-      buffer_initial_size: 200_000 # How many samples we need in the buffer before we start using replay buffer
+      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
       buffer_capacity: 1_000_000

       # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer
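The six config changes above all restore buffer_initial_size to 80_000 (previously 200_000), i.e. the number of transitions that must be collected before the replay buffer starts serving samples for training. As a rough illustration only (a hypothetical WarmupReplayBuffer, not the vel.rl.buffers.circular_replay_buffer implementation), such a warm-up gate can be sketched like this:

# Hypothetical sketch of a circular replay buffer with a warm-up threshold;
# names and structure are illustrative, not taken from vel.
import random
from collections import deque


class WarmupReplayBuffer:
    def __init__(self, capacity=1_000_000, initial_size=80_000):
        self.storage = deque(maxlen=capacity)  # oldest transitions are overwritten once full
        self.initial_size = initial_size

    def store(self, transition):
        self.storage.append(transition)

    def is_ready(self):
        # Sampling only starts once the warm-up threshold has been reached
        return len(self.storage) >= self.initial_size

    def sample(self, batch_size):
        assert self.is_ready(), "replay buffer is still warming up"
        return random.sample(list(self.storage), batch_size)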
New file (+102): 'breakout_rainbow' Rainbow configuration

name: 'breakout_rainbow'


env:
  name: vel.rl.env.classic_atari
  game: 'BreakoutNoFrameskip-v4'
  settings:
    max_episode_frames: 108_000


vec_env:
  name: vel.rl.vecenv.dummy
  frame_history: 4 # How many stacked frames go into a single observation


model:
  name: vel.rl.models.q_rainbow_model

  atoms: 51 # 51 bins for Distributional DQN
  vmin: -10.0
  vmax: 10.0

  initial_std_dev: 0.5
  factorized_noise: true

  input_block:
    name: vel.modules.input.image_to_tensor

  backbone:
    name: vel.rl.models.backbone.double_noisy_nature_cnn
    input_width: 84
    input_height: 84
    input_channels: 4 # The same as frame_history

    initial_std_dev: 0.5
    factorized_noise: true


reinforcer:
  name: vel.rl.reinforcers.buffered_off_policy_iteration_reinforcer

  env_roller:
    name: vel.rl.env_roller.transition_replay_env_roller

    # N-Step Q-Learning
    forward_steps: 3
    discount_factor: 0.99

    replay_buffer:
      name: vel.rl.buffers.prioritized_circular_replay_buffer

      buffer_initial_size: 80_000 # How many samples we need in the buffer before we start using replay buffer
      buffer_capacity: 1_000_000

      # Because env has a framestack already built-in, save memory by encoding only last frames in the replay buffer
      frame_stack_compensation: true
      frame_history: 4 # How many stacked frames go into a single observation

      priority_exponent: 0.5
      priority_weight:
        name: vel.schedules.linear
        initial_value: 0.4
        final_value: 1.0

      priority_epsilon: 1.0e-6

  algo:
    name: vel.rl.algo.distributional_dqn
    double_dqn: true

    target_update_frequency: 32_000 # After how many batches to update the target network
    max_grad_norm: 0.5

    discount_factor: 0.99

  rollout_steps: 4 # How many environment steps (per env) to perform per batch of training
  training_steps: 32 # How many environment steps (per env) to perform per training round
  parallel_envs: 1 # Roll out only one env in parallel, just like in DeepMind paper


optimizer:
  name: vel.optimizers.adam
  lr: 6.25e-05
  epsilon: 1.5e-4


commands:
  train:
    name: vel.rl.commands.rl_train_command
    total_frames: 1.1e7 # 11M
    batches_per_epoch: 2500

  record:
    name: vel.rl.commands.record_movie_command
    takes: 10
    videoname: 'breakout_rainbow_vid_{:04}.avi'
    fps: 15

  evaluate:
    name: vel.rl.commands.evaluate_env_command
    parallel_envs: 12
    takes: 20
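In this config, prioritized replay uses priority_exponent 0.5 and anneals the importance-sampling exponent (priority_weight) linearly from 0.4 to 1.0 over training, as in the Rainbow paper. A rough sketch of how such a schedule and the resulting importance-sampling weights interact (illustrative only, not the vel.schedules.linear implementation):

# Illustrative sketch; function names are hypothetical, not vel APIs.
import numpy as np


def linear_schedule(progress, initial_value=0.4, final_value=1.0):
    """Interpolate linearly between initial and final value as training progress goes 0 -> 1."""
    return initial_value + (final_value - initial_value) * progress


def importance_weights(priorities, beta, alpha=0.5, epsilon=1.0e-6):
    """Importance-sampling weights for prioritized replay (Schaul et al., 2016)."""
    probs = (priorities + epsilon) ** alpha        # alpha corresponds to priority_exponent
    probs = probs / probs.sum()
    weights = (len(priorities) * probs) ** (-beta)
    return weights / weights.max()                 # normalize so the largest weight is 1.0


beta = linear_schedule(progress=0.25)                          # a quarter of the way through training
weights = importance_weights(np.array([2.0, 0.5, 1.0]), beta)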

vel/rl/algo/distributional_dqn.py (+1 -1)

@@ -180,7 +180,7 @@ def metrics(self) -> list:


 def create(model: ModelFactory, discount_factor: float, target_update_frequency: int,
-           max_grad_norm: float, double_dqn: bool=False):
+           max_grad_norm: float, double_dqn: bool = False):
     """ Vel factory function """
     return DistributionalDeepQLearning(
         model_factory=model,
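For context, vel.rl.algo.distributional_dqn implements distributional (C51) Q-learning, which the Rainbow config above parametrizes with atoms: 51, vmin: -10.0, vmax: 10.0. A rough sketch of the categorical target projection that this family of algorithms performs (illustrative only, not the code from this file):

# Illustrative sketch of the C51 target projection (Bellemare et al., 2017);
# not the vel implementation.
import torch


def project_distribution(next_probs, rewards, dones, atoms=51, vmin=-10.0, vmax=10.0, gamma=0.99):
    """Project r + gamma * z onto the fixed support of `atoms` bins in [vmin, vmax]."""
    batch_size = rewards.size(0)
    delta_z = (vmax - vmin) / (atoms - 1)
    support = torch.linspace(vmin, vmax, atoms)

    # Bellman-update the support, then clamp it back into the representable range
    tz = (rewards.unsqueeze(1) + gamma * (1.0 - dones).unsqueeze(1) * support).clamp(vmin, vmax)
    b = (tz - vmin) / delta_z            # fractional atom index of each projected point
    lower = b.floor().long()
    upper = b.ceil().long()
    # Avoid losing probability mass when b lands exactly on an atom
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < atoms - 1) & (lower == upper)] += 1

    # Distribute each atom's probability mass to its two neighbouring bins
    projected = torch.zeros(batch_size, atoms)
    projected.scatter_add_(1, lower, next_probs * (upper.float() - b))
    projected.scatter_add_(1, upper, next_probs * (b - lower.float()))
    return projected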
New file (+121): DoubleNoisyNatureCnn backbone (the vel.rl.models.backbone.double_noisy_nature_cnn module referenced by the config above)

"""
Code based loosely on implementation:
https://github.com/openai/baselines/blob/master/baselines/ppo2/policies.py

Under MIT license.
"""
import numpy as np

import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

import vel.util.network as net_util

from vel.api import LinearBackboneModel, ModelFactory
from vel.rl.modules.noisy_linear import NoisyLinear


class DoubleNoisyNatureCnn(LinearBackboneModel):
    """
    Neural network as defined in the paper 'Human-level control through deep reinforcement learning'
    but with two separate heads and "noisy" linear layer.
    """
    def __init__(self, input_width, input_height, input_channels, output_dim=512, initial_std_dev=0.4,
                 factorized_noise=True):
        super().__init__()

        self._output_dim = output_dim

        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=32,
            kernel_size=(8, 8),
            stride=4
        )

        self.conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=64,
            kernel_size=(4, 4),
            stride=2
        )

        self.conv3 = nn.Conv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=(3, 3),
            stride=1
        )

        self.final_width = net_util.convolutional_layer_series(input_width, [
            (8, 0, 4),
            (4, 0, 2),
            (3, 0, 1)
        ])

        self.final_height = net_util.convolutional_layer_series(input_height, [
            (8, 0, 4),
            (4, 0, 2),
            (3, 0, 1)
        ])

        self.linear_layer_one = NoisyLinear(
            # 64 is the number of channels of the last conv layer
            self.final_width * self.final_height * 64,
            self.output_dim,
            initial_std_dev=initial_std_dev,
            factorized_noise=factorized_noise
        )

        self.linear_layer_two = NoisyLinear(
            # 64 is the number of channels of the last conv layer
            self.final_width * self.final_height * 64,
            self.output_dim,
            initial_std_dev=initial_std_dev,
            factorized_noise=factorized_noise
        )

    @property
    def output_dim(self) -> int:
        """ Final dimension of model output """
        return self._output_dim

    def reset_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                init.orthogonal_(m.weight, gain=np.sqrt(2))
                init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):
                # init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                init.orthogonal_(m.weight, gain=np.sqrt(2))
                init.constant_(m.bias, 0.0)
            elif isinstance(m, NoisyLinear):
                m.reset_weights()

    def forward(self, image):
        result = image
        result = F.relu(self.conv1(result))
        result = F.relu(self.conv2(result))
        result = F.relu(self.conv3(result))
        flattened = result.view(result.size(0), -1)

        output_one = F.relu(self.linear_layer_one(flattened))
        output_two = F.relu(self.linear_layer_two(flattened))

        return output_one, output_two


def create(input_width, input_height, input_channels=1, output_dim=512, initial_std_dev=0.4, factorized_noise=True):
    """ Vel factory function """
    def instantiate(**_):
        return DoubleNoisyNatureCnn(
            input_width=input_width, input_height=input_height, input_channels=input_channels,
            output_dim=output_dim, initial_std_dev=initial_std_dev, factorized_noise=factorized_noise
        )

    return ModelFactory.generic(instantiate)


DoubleNoisyNatureCnnFactory = create
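A minimal usage sketch of this backbone (assuming PyTorch and this repository's vel package are importable); the two output streams are presumably consumed by the two streams of the Rainbow model's head:

# Usage sketch; assumes torch and the vel package from this repository are installed.
import torch

from vel.rl.models.backbone.double_noisy_nature_cnn import DoubleNoisyNatureCnn

backbone = DoubleNoisyNatureCnn(
    input_width=84, input_height=84, input_channels=4,
    output_dim=512, initial_std_dev=0.5, factorized_noise=True
)
backbone.reset_weights()

frames = torch.zeros(1, 4, 84, 84)          # one stacked 84x84x4 Atari observation
stream_one, stream_two = backbone(frames)   # two (1, 512) feature vectors
print(stream_one.shape, stream_two.shape)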

vel/rl/models/q_dueling_model.py (+2 -2)

@@ -4,7 +4,7 @@
 from vel.api import LinearBackboneModel, Model, ModelFactory, BackboneModel
 from vel.modules.input.identity import IdentityFactory
 from vel.rl.api import Rollout, Evaluator
-from vel.rl.modules.dueling_q_head import DuelingQHead
+from vel.rl.modules.q_dueling_head import QDuelingHead
 from vel.rl.models.q_model import QModelEvaluator


@@ -21,7 +21,7 @@ def __init__(self, input_block: BackboneModel, backbone: LinearBackboneModel, ac

         self.input_block = input_block
         self.backbone = backbone
-        self.q_head = DuelingQHead(input_dim=backbone.output_dim, action_space=action_space)
+        self.q_head = QDuelingHead(input_dim=backbone.output_dim, action_space=action_space)

     def forward(self, observations):
         """ Model forward pass """

vel/rl/models/q_noisy_model.py (+2 -2)

@@ -5,7 +5,7 @@
 from vel.modules.input.identity import IdentityFactory
 from vel.rl.api import Rollout, RlModel, Evaluator
 from vel.rl.models.q_model import QModelEvaluator
-from vel.rl.modules.noisy_q_head import NoisyQHead
+from vel.rl.modules.q_noisy_head import QNoisyHead


 class NoisyQModel(RlModel):
@@ -22,7 +22,7 @@ def __init__(self, input_block: BackboneModel, backbone: LinearBackboneModel, ac

         self.input_block = input_block
         self.backbone = backbone
-        self.q_head = NoisyQHead(
+        self.q_head = QNoisyHead(
             input_dim=backbone.output_dim, action_space=action_space, initial_std_dev=initial_std_dev,
             factorized_noise=factorized_noise
         )
