
Commit 3081b0e

Merge branch 'HotFix_dqn_example'

2 parents 444a510 + a6298d9

File tree

1 file changed (+139, -111 lines)

example/dqn.py

Lines changed: 139 additions & 111 deletions
@@ -1,147 +1,175 @@
 import os
 import datetime
-import io
-import base64
 
 import numpy as np
 
 import gym
+from scipy.special import softmax
+
 
 import tensorflow as tf
 from tensorflow.keras.models import Sequential,clone_model
-from tensorflow.keras.layers import InputLayer,Dense
+from tensorflow.keras.layers import Dense
 from tensorflow.keras.optimizers import Adam
-from tensorflow.keras.callbacks import EarlyStopping,TensorBoard
 from tensorflow.summary import create_file_writer
 
-from scipy.special import softmax
 
-from cpprb import create_buffer, ReplayBuffer,PrioritizedReplayBuffer
-import cpprb.gym
+from cpprb import ReplayBuffer,PrioritizedReplayBuffer
 
 
 gamma = 0.99
 batch_size = 1024
 
-N_iteration = 101
-N_show = 10
-
-per_train = 100
+N_iteration = int(1e+5)
+target_update_freq = 50
 
 prioritized = True
 
-egreedy = True
+egreedy = 0.1
 
-loss = "huber_loss"
-# loss = "mean_squared_error"
-
-a = cpprb.gym.NotebookAnimation()
+# Log
 dir_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-
-
 logdir = os.path.join("logs", dir_name)
 writer = create_file_writer(logdir + "/metrics")
 writer.set_as_default()
 
-env = gym.make('CartPole-v0')
-env = gym.wrappers.Monitor(env,
-                           logdir + "/video/",
-                           force=True,
-                           video_callable=(lambda ep: ep % 50 == 0))
-
 
-observation = env.reset()
+# Env
+env = gym.make('CartPole-v1')
 
-model = Sequential([InputLayer(input_shape=(observation.shape)), # 4 for CartPole
-                    Dense(64,activation='relu'),
+# For CartPole: input 4, output 2
+model = Sequential([Dense(64,activation='relu',
+                          input_shape=(env.observation_space.shape)),
                     Dense(64,activation='relu'),
-                    Dense(env.action_space.n)]) # 2 for CartPole
-
+                    Dense(env.action_space.n)])
 target_model = clone_model(model)
 
 
+# Loss Function
+
+@tf.function
+def Huber_loss(absTD):
+    return tf.where(absTD < 1.0, absTD, tf.math.square(absTD))
+
+@tf.function
+def MSE(absTD):
+    return tf.math.square(absTD)
+
+loss_func = Huber_loss
+
+
 optimizer = Adam()
-tensorboard_callback = TensorBoard(logdir, histogram_freq=1)
-
-
-model.compile(loss = loss,
-              optimizer = optimizer,
-              metrics=['accuracy'])
-
-rb = create_buffer(1e6,
-                   {"obs":{"shape": observation.shape},
-                    "act":{"shape": 1,"dtype": np.ubyte},
-                    "rew": {},
-                    "next_obs": {"shape": observation.shape},
-                    "done": {}},
-                   prioritized = prioritized)
-
-action_index = np.arange(env.action_space.n).reshape(1,-1)
-
-# Bootstrap
-for n_episode in range (1000):
-    observation = env.reset()
-    for t in range(500):
-        action = env.action_space.sample() # Random Action
-        next_observation, reward, done, info = env.step(action)
-        rb.add(obs=observation,
-               act=action,
-               rew=reward,
-               next_obs=next_observation,
-               done=done)
-        observation = next_observation
-        if done:
-            break
-
-
-for n_episode in range(N_iteration):
-    observation = env.reset()
-    sum_reward = 0
-    for t in range(500):
-        actions = softmax(np.ravel(model.predict(observation.reshape(1,-1),
-                                                 batch_size=1)))
-        actions = actions / actions.sum()
-
-        if egreedy:
-            if np.random.rand() < 0.9:
-                action = np.argmax(actions)
-            else:
-                action = env.action_space.sample()
-        else:
-            action = np.random.choice(actions.shape[0],p=actions)
-
-        next_observation, reward, done, info = env.step(action)
-        sum_reward += reward
-        rb.add(obs=observation,
-               act=action,
-               rew=reward,
-               next_obs=next_observation,
-               done=done)
-        observation = next_observation
-
-        sample = rb.sample(batch_size)
-        Q_pred = model.predict(sample["obs"])
-        Q_true = target_model.predict(sample['next_obs']).max(axis=1,keepdims=True)*gamma*(1.0 - sample["done"]) + sample['rew']
-        target = tf.where(tf.one_hot(tf.cast(tf.reshape(sample["act"],[-1]),
-                                             dtype=tf.int32),
-                                     env.action_space.n,
-                                     True,False),
-                          tf.broadcast_to(Q_true,[batch_size,env.action_space.n]),
-                          Q_pred)
-
-        if prioritized:
-            TD = np.square(target - Q_pred).sum(axis=1)
-            rb.update_priorities(sample["indexes"],TD)
-
-        model.fit(x=sample['obs'],
-                  y=target,
-                  batch_size=batch_size,
-                  verbose = 0)
-
-        if done:
-            break
-
-    if n_episode % 10 == 0:
+
+
+buffer_size = 1e+6
+env_dict = {"obs":{"shape": env.observation_space.shape},
+            "act":{"shape": 1,"dtype": np.ubyte},
+            "rew": {},
+            "next_obs": {"shape": env.observation_space.shape},
+            "done": {}}
+
+if prioritized:
+    rb = PrioritizedReplayBuffer(buffer_size,env_dict)
+else:
+    rb = ReplayBuffer(buffer_size,env_dict)
+
+
+@tf.function
+def Q_func(model,obs,act,act_shape):
+    return tf.reduce_sum(model(obs) * tf.one_hot(act,depth=act_shape), axis=1)
+
+@tf.function
+def DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
+    return gamma*tf.reduce_max(target(next_obs),axis=1)*(1.0-done) + rew
+
+@tf.function
+def Double_DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
+    act = tf.math.argmax(model(next_obs),axis=1)
+    return gamma*tf.reduce_sum(target(next_obs)*tf.one_hot(act,depth=act_shape), axis=1)*(1.0-done) + rew
+
+
+target_func = DQN_target_func
+
+
+
+# Start Experiment
+
+observation = env.reset()
+
+# Warming up
+for n_step in range(100):
+    action = env.action_space.sample() # Random Action
+    next_observation, reward, done, info = env.step(action)
+    rb.add(obs=observation,
+           act=action,
+           rew=reward,
+           next_obs=next_observation,
+           done=done)
+    observation = next_observation
+    if done:
+        env.reset()
+        rb.on_episode_end()
+
+
+sum_reward = 0
+n_episode = 0
+observation = env.reset()
+for n_step in range(N_iteration):
+    Q = tf.squeeze(model(observation.reshape(1,-1)))
+
+    if np.random.rand() < egreedy:
+        action = env.action_space.sample()
+    else:
+        action = np.argmax(Q)
+
+    next_observation, reward, done, info = env.step(action)
+    sum_reward += reward
+    rb.add(obs=observation,
+           act=action,
+           rew=reward,
+           next_obs=next_observation,
+           done=done)
+    observation = next_observation
+
+    sample = rb.sample(batch_size)
+    weights = sample["weights"].ravel() if prioritized else tf.constant(1.0)
+
+    with tf.GradientTape() as tape:
+        tape.watch(model.trainable_weights)
+        Q = Q_func(model,
+                   tf.constant(sample["obs"]),
+                   tf.constant(sample["act"].ravel()),
+                   tf.constant(env.action_space.n))
+        target_Q = target_func(model,target_model,
+                               tf.constant(sample['next_obs']),
+                               tf.constant(sample["rew"].ravel()),
+                               tf.constant(sample["done"].ravel()),
+                               tf.constant(gamma),
+                               tf.constant(env.action_space.n))
+        absTD = tf.math.abs(target_Q - Q)
+        loss = tf.reduce_mean(loss_func(absTD)*weights)
+
+    grad = tape.gradient(loss,model.trainable_weights)
+    optimizer.apply_gradients(zip(grad,model.trainable_weights))
+
+
+    if prioritized:
+        Q = Q_func(model,
+                   tf.constant(sample["obs"]),
+                   tf.constant(sample["act"].ravel()),
+                   tf.constant(env.action_space.n))
+        absTD = tf.math.abs(target_Q - Q)
+        rb.update_priorities(sample["indexes"],absTD)
+
+    if done:
+        env.reset()
+        rb.on_episode_end()
+        tf.summary.scalar("total reward vs episode",data=sum_reward,step=n_episode)
+        tf.summary.scalar("total reward vs training step",data=sum_reward,step=n_step)
+        sum_reward = 0
+        n_episode += 1
+
+    if n_step % target_update_freq == 0:
         target_model.set_weights(model.get_weights())
 
-        tf.summary.scalar("reward",data=sum_reward,step=n_episode)
+    tf.summary.scalar("reward vs training step",data=reward,step=n_step)
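Note: the merged example defines both update targets (DQN_target_func and Double_DQN_target_func) and both losses (Huber_loss and MSE) but wires in only DQN_target_func and Huber_loss. A minimal sketch of how the alternatives would be selected, using only names defined in the new file above; because the training loop looks up target_func and loss_func as module-level globals, reassigning them before the "# Start Experiment" section should be the only change needed:

# Sketch, assuming the merged example/dqn.py above: switch the example to
# Double DQN with a squared-error loss by reassigning the two selector globals.
loss_func = MSE                        # defined above: squared TD error
target_func = Double_DQN_target_func   # defined above: action chosen by the online model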
