-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path01_dqn_basic.py
executable file
·64 lines (52 loc) · 2.01 KB
/
01_dqn_basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
import argparse
import random

import gym
import numpy as np
import ptan
import torch
import torch.optim as optim
from ignite.engine import Engine

from lib import dqn_model, common
NAME = "01_baseline"  # run name handed to common.setup_ignite in main()


def seed_all(seed):
    """Seed the Python, NumPy and PyTorch global RNGs with *seed*.

    NumPy is seeded as well because ptan's epsilon-greedy action selector
    and experience replay buffer sample from NumPy's global RNG; seeding
    only ``random`` and ``torch`` would leave exploration and replay
    sampling unreproducible between runs.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def main():
    """Train the baseline DQN agent on the 'pong' hyperparameter set.

    Command line:
        --cuda  place the networks on the CUDA device instead of the CPU.

    Training is driven by a PyTorch Ignite ``Engine`` whose per-iteration
    handler is ``process_batch``; logging/termination handlers are attached
    by ``common.setup_ignite``.
    """
    seed_all(common.SEED)
    params = common.HYPERPARAMS['pong']

    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    env = gym.make(params.env_name)
    env = ptan.common.wrappers.wrap_dqn(env)
    # NOTE(review): env.seed() is the pre-0.26 gym API; newer gym releases
    # seed through env.reset(seed=...) instead — keep in sync with the
    # gym version ptan is pinned against.
    env.seed(common.SEED)

    net = dqn_model.DQN(env.observation_space.shape,
                        env.action_space.n).to(device)
    # Target network: its model is what calc_loss_dqn receives for the
    # target side, and it is re-synced from `net` every
    # params.target_net_sync iterations in process_batch.
    tgt_net = ptan.agent.TargetNet(net)
    selector = ptan.actions.EpsilonGreedyActionSelector(
        epsilon=params.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = ptan.agent.DQNAgent(net, selector, device=device)

    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env, agent, gamma=params.gamma)
    buffer = ptan.experience.ExperienceReplayBuffer(
        exp_source, buffer_size=params.replay_size)
    optimizer = optim.Adam(net.parameters(),
                           lr=params.learning_rate)

    def process_batch(engine, batch):
        """Run one optimisation step on *batch*; return metrics for the engine.

        Returns:
            dict with the scalar ``"loss"`` of this step and the selector's
            current ``"epsilon"``.
        """
        optimizer.zero_grad()
        loss_v = common.calc_loss_dqn(
            batch, net, tgt_net.target_model,
            gamma=params.gamma, device=device)
        loss_v.backward()
        optimizer.step()
        # Epsilon schedule is driven by the engine's iteration counter
        # (presumably a decay from epsilon_start — see common.EpsilonTracker).
        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {
            "loss": loss_v.item(),
            "epsilon": selector.epsilon,
        }

    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, NAME)
    # NOTE(review): batch_generator presumably warms the buffer up to
    # params.replay_initial entries before yielding params.batch_size
    # batches — confirm in lib/common.
    engine.run(common.batch_generator(buffer, params.replay_initial,
                                      params.batch_size))


if __name__ == "__main__":
    main()