Skip to content

Commit 3534e1c

Browse files
author
xyliao
committed
finish gym
1 parent c48c020 commit 3534e1c

File tree

4 files changed

+154
-2
lines changed

4 files changed

+154
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Learn Deep Learning with PyTorch
6565

6666
- Chapter 7: 深度强化学习
6767
- [Q Learning](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter7_RL/q-learning-intro.ipynb)
68-
- Open AI gym
68+
- [Open AI gym](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter7_RL/open_ai_gym.ipynb)
6969
- [Deep Q-networks](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter7_RL/dqn.ipynb)
7070

7171
- Chapter 8: PyTorch高级

chapter7_RL/dqn.ipynb

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
"\n",
1010
"一个非常简单的办法就是使用深度学习来解决这个问题,所以出现了一种新的网络,叫做 Deep Q Networks,将 Q learning 和 神经网络结合在了一起,对于每一个 state,我们都可以使用神经网络来计算对应动作的值,就不在需要建立一张表格,而且网络更新比表格更新更有效率,获取结果也更加高效。\n",
1111
"\n",
12+
"![](https://ws4.sinaimg.cn/large/006tKfTcgy1fni66at6jbj30xo0g1jut.jpg)\n",
13+
"\n",
1214
"下面我们使用 open ai gym 环境中的 CartPole 来尝试实现一个简单的 DQN。"
1315
]
1416
},
@@ -400,7 +402,9 @@
400402
"cell_type": "markdown",
401403
"metadata": {},
402404
"source": [
403-
"我们画出 reward 的曲线,可以发现奖励在不断变多,说明我们的 agent 学得越来越好,同时我们也可以实实在在地看到 agent 玩得怎么样,gym 提供了可视化的过程,但是 notebook 里面没有办法显示,我们可以使用运行 `dqn.py` 来看到 agent 玩的可视化视频。"
405+
"我们画出 reward 的曲线,可以发现奖励在不断变多,说明我们的 agent 学得越来越好,同时我们也可以实实在在地看到 agent 玩得怎么样,gym 提供了可视化的过程,但是 notebook 里面没有办法显示,我们可以使用运行 `dqn.py` 来看到 agent 玩的可视化视频。\n",
406+
"\n",
407+
"另外,我们这里只使用了简单的多层神经网络来作为 dqn 的网络结构,网络的输入是杆的位置信息和角度等等,我们当然可以使用更加一般的输入,比如说每个状态都是一个图片的输入,那么这种方式更具有一般性,实现上几乎是一模一样,只需要改一改网络结构,同时 gym 中也可以得到每个屏幕的输出,具体可以看看 pytorch 的[官方例子](http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html#)。"
404408
]
405409
}
406410
],

chapter7_RL/mount-car.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import numpy as np
2+
3+
import gym
4+
5+
n_states = 40 # 取样 40 个状态
6+
iter_max = 10000
7+
8+
initial_lr = 1.0 # Learning rate
9+
min_lr = 0.003
10+
gamma = 1.0
11+
t_max = 10000
12+
eps = 0.02
13+
14+
15+
def run_episode(env, policy=None, render=False):
16+
obs = env.reset()
17+
total_reward = 0
18+
step_idx = 0
19+
for _ in range(t_max):
20+
if render:
21+
env.render()
22+
if policy is None: # 如果没有策略,就随机取样
23+
action = env.action_space.sample()
24+
else:
25+
a, b = obs_to_state(env, obs)
26+
action = policy[a][b]
27+
obs, reward, done, _ = env.step(action)
28+
total_reward += gamma ** step_idx * reward
29+
step_idx += 1
30+
if done:
31+
break
32+
return total_reward
33+
34+
35+
def obs_to_state(env, obs):
36+
"""
37+
将观察的连续环境映射到离散的输入的状态
38+
"""
39+
env_low = env.observation_space.low
40+
env_high = env.observation_space.high
41+
env_dx = (env_high - env_low) / n_states
42+
a = int((obs[0] - env_low[0]) / env_dx[0])
43+
b = int((obs[1] - env_low[1]) / env_dx[1])
44+
return a, b
45+
46+
47+
if __name__ == '__main__':
48+
env_name = 'MountainCar-v0'
49+
env = gym.make(env_name)
50+
env.seed(0)
51+
np.random.seed(0)
52+
print('----- using Q Learning -----')
53+
q_table = np.zeros((n_states, n_states, 3))
54+
for i in range(iter_max):
55+
obs = env.reset()
56+
total_reward = 0
57+
## eta: 每一步学习率都不断减小
58+
eta = max(min_lr, initial_lr * (0.85 ** (i // 100)))
59+
for j in range(t_max):
60+
x, y = obs_to_state(env, obs)
61+
if np.random.uniform(0, 1) < eps: # greedy 贪心算法
62+
action = np.random.choice(env.action_space.n)
63+
else:
64+
logits = q_table[x, y, :]
65+
logits_exp = np.exp(logits)
66+
probs = logits_exp / np.sum(logits_exp) # 算出三个动作的概率
67+
action = np.random.choice(env.action_space.n, p=probs) # 依概率来选择动作
68+
obs, reward, done, _ = env.step(action)
69+
total_reward += reward
70+
# 更新 q 表
71+
x_, y_ = obs_to_state(env, obs)
72+
q_table[x, y, action] = q_table[x, y, action] + eta * (
73+
reward + gamma * np.max(q_table[x_, y_, :]) -
74+
q_table[x, y, action])
75+
if done:
76+
break
77+
if i % 100 == 0:
78+
print('Iteration #%d -- Total reward = %d.' % (i + 1,
79+
total_reward))
80+
solution_policy = np.argmax(q_table, axis=2) # 在 q 表中每个状态下都取最大的值得动作
81+
solution_policy_scores = [
82+
run_episode(env, solution_policy, False) for _ in range(100)
83+
]
84+
print("Average score of solution = ", np.mean(solution_policy_scores))
85+
# Animate it
86+
run_episode(env, solution_policy, True)

chapter7_RL/open_ai_gym.ipynb

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"collapsed": true
7+
},
8+
"source": [
9+
"# Gym 介绍\n",
10+
"前面我们简单的介绍了强化学习的例子,从这个例子可以发现,构建强化学习的环境非常麻烦,需要耗费我们大量的时间,这个时候我们可以使用一个开源的工具,叫做 gym,是由 open ai 开发的。\n",
11+
"\n",
12+
"在这个库中从简单的走格子到毁灭战士,提供了各种各样的游戏环境可以让大家放自己的 AI 进去玩耍。取名叫 gym 也很有意思,可以想象一群 AI 在健身房里各种锻炼,磨练技术。\n",
13+
"\n",
14+
"使用起来也非常方便,首先在终端内输入如下代码进行安装。\n",
15+
"\n",
16+
"```\n",
17+
"# Github源\n",
18+
"git clone https://github.com/openai/gym\n",
19+
"cd gym\n",
20+
"pip install -e .[all]\n",
21+
"\n",
22+
"# 直接下载gym包\n",
23+
"pip install gym[all]\n",
24+
"```\n",
25+
"\n",
26+
"我们可以访问这个页面看到 gym 所[包含的环境和介绍](https://github.com/openai/gym/wiki)。"
27+
]
28+
},
29+
{
30+
"cell_type": "markdown",
31+
"metadata": {},
32+
"source": [
33+
"在上面的环境页面,可以 gym 内置了很多环境,我们可以使用前面讲过的 q learning 尝试一个 gym 中的小例子,[mountain car](https://github.com/openai/gym/wiki/MountainCar-v0)。在 mounttain car,我们能够观察到环境中小车的位置,也就是坐标,我们能够采取的动作是向左或者向右。\n",
34+
"\n",
35+
"为了使用 q learning,我们必须要建立 q 表,而这里的状态空间是连续不可数的,所以我们需要离散化连续空间,将 x 坐标和 y 坐标都平均分成很多份,具体的实现可以运行 `mount-car.py` 看看结果。\n",
36+
"\n",
37+
"如果运行完之后,可以看到 q 表的收敛非常慢,reward 一直都很难变化,我们需要很久才能将小车推到终点,这个时候我们需要一个更加强大的武器,那就 deep q network。"
38+
]
39+
}
40+
],
41+
"metadata": {
42+
"kernelspec": {
43+
"display_name": "Python 3",
44+
"language": "python",
45+
"name": "python3"
46+
},
47+
"language_info": {
48+
"codemirror_mode": {
49+
"name": "ipython",
50+
"version": 3
51+
},
52+
"file_extension": ".py",
53+
"mimetype": "text/x-python",
54+
"name": "python",
55+
"nbconvert_exporter": "python",
56+
"pygments_lexer": "ipython3",
57+
"version": "3.6.3"
58+
}
59+
},
60+
"nbformat": 4,
61+
"nbformat_minor": 2
62+
}

0 commit comments

Comments
 (0)